From a17e5f3707a59f979c2e8cc749137be851f5209b Mon Sep 17 00:00:00 2001 From: Ahmet Guzererler Date: Fri, 20 Mar 2026 14:06:50 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20complete=20portfolio=20data=20foundatio?= =?UTF-8?q?n=20=E2=80=94=20psycopg2=20client,=20repository,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace supabase-py stubs with working psycopg2 implementation using Supabase pooler connection string. Implement full business logic in repository (avg cost basis, cash accounting, trade recording, snapshots). Add 12 unit tests + 4 integration tests (51 total portfolio tests pass). Fix cash_pct bug in models.py, update docs for psycopg2 + pooler pattern. Co-Authored-By: Claude Opus 4.6 --- docs/agent/CURRENT_STATE.md | 15 +- docs/portfolio/00_overview.md | 34 +- docs/portfolio/01_phase1_plan.md | 19 +- tests/portfolio/conftest.py | 20 +- tests/portfolio/test_repository.py | 320 +++++++++++--- tradingagents/portfolio/config.py | 107 +++-- tradingagents/portfolio/models.py | 2 +- tradingagents/portfolio/repository.py | 330 +++++++-------- tradingagents/portfolio/supabase_client.py | 467 +++++++++++++-------- 9 files changed, 816 insertions(+), 498 deletions(-) diff --git a/docs/agent/CURRENT_STATE.md b/docs/agent/CURRENT_STATE.md index b1155699..93555e26 100644 --- a/docs/agent/CURRENT_STATE.md +++ b/docs/agent/CURRENT_STATE.md @@ -1,20 +1,21 @@ # Current Milestone -Daily digest consolidation and Google NotebookLM sync shipped (PR open: `feat/daily-digest-notebooklm`). All analyses now append to a single `daily_digest.md` per day and auto-upload to NotebookLM via `nlm` CLI. Next: PR review and merge. +Portfolio Manager Phase 1 (data foundation) complete and merged. All 4 Supabase tables live, 51 tests passing (including integration tests against live DB). # Recent Progress +- **PR #32 merged**: Portfolio Manager data foundation — models, SQL schema, module scaffolding + - `tradingagents/portfolio/` — full module: models, config, exceptions, supabase_client (psycopg2), report_store, repository + - `migrations/001_initial_schema.sql` — 4 tables (portfolios, holdings, trades, snapshots) with constraints, indexes, triggers + - `tests/portfolio/` — 51 tests: 20 model, 15 report_store, 12 repository unit, 4 integration + - Uses `psycopg2` direct PostgreSQL via Supabase pooler (`aws-1-eu-west-1.pooler.supabase.com:6543`) + - Business logic: avg cost basis, cash accounting, trade recording, snapshots - **PR #22 merged**: Unified report paths, structured observability logging, memory system update - **feat/daily-digest-notebooklm** (shipped): Daily digest consolidation + NotebookLM source sync - - `tradingagents/daily_digest.py` — `append_to_digest()` appends timestamped entries to `reports/daily/{date}/daily_digest.md` - - `tradingagents/notebook_sync.py` — `sync_to_notebooklm()` deletes existing "Daily Trading Digest" source then uploads new content via `nlm source add --text --wait`. - - `tradingagents/report_paths.py` — added `get_digest_path(date)` - - `cli/main.py` — `analyze` and `scan` commands both call digest + sync after each run - - `.env.example` — fixed consistency, removed duplicates, aligned with `NOTEBOOKLM_ID` -- **Verification**: 220+ offline tests passing + 5 new unit tests for `notebook_sync.py` + live integration test passed. # In Progress +- Portfolio Manager Phase 2: Holding Reviewer Agent (next) - Refinement of macro scan synthesis prompts (ongoing) # Active Blockers diff --git a/docs/portfolio/00_overview.md b/docs/portfolio/00_overview.md index 741a2605..17753d2c 100644 --- a/docs/portfolio/00_overview.md +++ b/docs/portfolio/00_overview.md @@ -72,28 +72,26 @@ investment portfolio end-to-end. It performs the following actions in sequence: The `report_path` column in the `portfolios` table points to the daily portfolio subdirectory on disk: `reports/daily/{date}/portfolio/`. -### Data Access Layer: raw `supabase-py` (no ORM) +### Data Access Layer: raw `psycopg2` (no ORM) -The Python code talks to Supabase through the raw `supabase-py` client — **no -ORM** (Prisma, SQLAlchemy, etc.) is used. +The Python code talks to Supabase PostgreSQL directly via `psycopg2` using the +**pooler connection string** (`SUPABASE_CONNECTION_STRING`). No ORM (Prisma, +SQLAlchemy) and no `supabase-py` REST client is used. -**Why not Prisma?** -- `prisma-client-py` requires a Node.js runtime for code generation — an - extra non-Python dependency in a Python-only project. -- Prisma's `prisma migrate` conflicts with Supabase's own SQL migration tooling - (we use `.sql` files in `tradingagents/portfolio/migrations/`). -- 4 tables with straightforward CRUD don't benefit from a code-generated ORM. +**Why `psycopg2` over `supabase-py`?** +- Direct SQL gives full control — transactions, upserts, `RETURNING *`, CTEs. +- No dependency on Supabase's PostgREST schema cache or API key types. +- `psycopg2-binary` is a single pip install with zero non-Python dependencies. +- 4 tables with straightforward CRUD don't benefit from an ORM or REST wrapper. -**Why not SQLAlchemy?** -- Supabase is accessed via PostgREST (HTTP API), not a direct TCP database - connection. SQLAlchemy is designed for direct connections and would bypass - Supabase's Row Level Security. -- Extra dependency overhead for a non-DB-heavy feature. +**Connection:** +- Uses `SUPABASE_CONNECTION_STRING` env var (pooler URI format). +- Passwords with special characters are auto-URL-encoded by `SupabaseClient._fix_dsn()`. +- Typical pooler URI: `postgresql://postgres.:@aws-1-.pooler.supabase.com:6543/postgres` -**`supabase-py` is sufficient because:** -- Its builder-pattern API (`client.table("holdings").select("*").eq(...)`) - covers all needed queries cleanly. -- Our own dataclasses handle type-safety via `to_dict()` / `from_dict()`. +**Why not Prisma / SQLAlchemy?** +- Prisma requires Node.js runtime — extra non-Python dependency. +- SQLAlchemy adds dependency overhead for 4 simple tables. - Plain SQL migration files are readable, versionable, and Supabase-native. > Full rationale: `docs/agent/decisions/012-portfolio-no-orm.md` diff --git a/docs/portfolio/01_phase1_plan.md b/docs/portfolio/01_phase1_plan.md index dd14c2ee..71e82629 100644 --- a/docs/portfolio/01_phase1_plan.md +++ b/docs/portfolio/01_phase1_plan.md @@ -89,14 +89,13 @@ All are optional and default to the values shown: | Env Var | Default | Description | |---------|---------|-------------| -| `SUPABASE_URL` | `""` | Supabase project URL | -| `SUPABASE_KEY` | `""` | Supabase anon/service key | -| `PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports | -| `PM_MAX_POSITIONS` | `15` | Max number of open positions | -| `PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight | -| `PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight | -| `PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve | -| `PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) | +| `SUPABASE_CONNECTION_STRING` | `""` | PostgreSQL pooler connection URI | +| `TRADINGAGENTS_PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports | +| `TRADINGAGENTS_PM_MAX_POSITIONS` | `15` | Max number of open positions | +| `TRADINGAGENTS_PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight | +| `TRADINGAGENTS_PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight | +| `TRADINGAGENTS_PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve | +| `TRADINGAGENTS_PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) | ### Acceptance Criteria @@ -172,7 +171,7 @@ Methods that query a single row raise `PortfolioNotFoundError` when no row is fo - Singleton — only one Supabase connection per process - All public methods fully type-annotated -- Supabase integration tests auto-skip when `SUPABASE_URL` is unset +- Supabase integration tests auto-skip when `SUPABASE_CONNECTION_STRING` is unset --- @@ -328,7 +327,7 @@ require PG functions — deferred to Phase 3+). ### Coverage Target 90 %+ for `models.py` and `report_store.py`. -Integration tests (`test_repository.py`) auto-skip when Supabase is unavailable. +Integration tests (`test_repository.py`) auto-skip when `SUPABASE_CONNECTION_STRING` is unset. --- diff --git a/tests/portfolio/conftest.py b/tests/portfolio/conftest.py index 073d318d..9a7d61f8 100644 --- a/tests/portfolio/conftest.py +++ b/tests/portfolio/conftest.py @@ -2,16 +2,16 @@ Fixtures provided: -- ``tmp_reports`` — temporary directory used as ReportStore base_dir -- ``sample_portfolio`` — a Portfolio instance for testing (not persisted) -- ``sample_holding`` — a Holding instance for testing (not persisted) -- ``sample_trade`` — a Trade instance for testing (not persisted) -- ``sample_snapshot`` — a PortfolioSnapshot instance for testing -- ``report_store`` — a ReportStore instance backed by tmp_reports -- ``mock_supabase_client`` — MagicMock of SupabaseClient for unit tests +- ``tmp_reports`` -- temporary directory used as ReportStore base_dir +- ``sample_portfolio`` -- a Portfolio instance for testing (not persisted) +- ``sample_holding`` -- a Holding instance for testing (not persisted) +- ``sample_trade`` -- a Trade instance for testing (not persisted) +- ``sample_snapshot`` -- a PortfolioSnapshot instance for testing +- ``report_store`` -- a ReportStore instance backed by tmp_reports +- ``mock_supabase_client`` -- MagicMock of SupabaseClient for unit tests Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when -``SUPABASE_URL`` is not set in the environment. +``SUPABASE_CONNECTION_STRING`` is not set in the environment. """ from __future__ import annotations @@ -36,8 +36,8 @@ from tradingagents.portfolio.supabase_client import SupabaseClient # --------------------------------------------------------------------------- requires_supabase = pytest.mark.skipif( - not os.getenv("SUPABASE_URL"), - reason="SUPABASE_URL not set — skipping Supabase integration tests", + not os.getenv("SUPABASE_CONNECTION_STRING"), + reason="SUPABASE_CONNECTION_STRING not set -- skipping integration tests", ) diff --git a/tests/portfolio/test_repository.py b/tests/portfolio/test_repository.py index 3658ab26..958d93b9 100644 --- a/tests/portfolio/test_repository.py +++ b/tests/portfolio/test_repository.py @@ -1,29 +1,59 @@ """Tests for tradingagents/portfolio/repository.py. -Tests the PortfolioRepository façade — business logic for holdings management, -cash accounting, avg-cost-basis updates, and snapshot creation. - -Supabase integration tests are automatically skipped when ``SUPABASE_URL`` is -not set in the environment (use the ``requires_supabase`` fixture marker). - Unit tests use ``mock_supabase_client`` to avoid DB access. +Integration tests auto-skip when ``SUPABASE_CONNECTION_STRING`` is unset. Run (unit tests only):: pytest tests/portfolio/test_repository.py -v -k "not integration" - -Run (with Supabase):: - - SUPABASE_URL=... SUPABASE_KEY=... pytest tests/portfolio/test_repository.py -v """ from __future__ import annotations +from unittest.mock import MagicMock, call + import pytest +from tradingagents.portfolio.exceptions import ( + HoldingNotFoundError, + InsufficientCashError, + InsufficientSharesError, +) +from tradingagents.portfolio.models import Holding, Portfolio, Trade +from tradingagents.portfolio.repository import PortfolioRepository from tests.portfolio.conftest import requires_supabase +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_repo(mock_client, report_store): + """Build a PortfolioRepository with mock client and real report store.""" + return PortfolioRepository( + client=mock_client, + store=report_store, + config={"data_dir": "reports", "max_positions": 15, + "max_position_pct": 0.15, "max_sector_pct": 0.35, + "min_cash_pct": 0.05, "default_budget": 100_000.0, + "supabase_connection_string": ""}, + ) + + +def _mock_portfolio(pid="pid-1", cash=10_000.0): + return Portfolio( + portfolio_id=pid, name="Test", cash=cash, + initial_cash=100_000.0, currency="USD", + ) + + +def _mock_holding(pid="pid-1", ticker="AAPL", shares=50.0, avg_cost=190.0): + return Holding( + holding_id="hid-1", portfolio_id=pid, ticker=ticker, + shares=shares, avg_cost=avg_cost, + ) + + # --------------------------------------------------------------------------- # add_holding — new position # --------------------------------------------------------------------------- @@ -31,12 +61,19 @@ from tests.portfolio.conftest import requires_supabase def test_add_holding_new_position(mock_supabase_client, report_store): """add_holding() on a ticker not yet held must create a new Holding.""" - # TODO: implement - # repo = PortfolioRepository(client=mock_supabase_client, store=report_store) - # mock portfolio with enough cash - # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) - # assert mock_supabase_client.upsert_holding.called - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + holding = repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + assert holding.ticker == "AAPL" + assert holding.shares == 10 + assert holding.avg_cost == 200.0 + assert mock_supabase_client.upsert_holding.called # --------------------------------------------------------------------------- @@ -45,16 +82,21 @@ def test_add_holding_new_position(mock_supabase_client, report_store): def test_add_holding_updates_avg_cost(mock_supabase_client, report_store): - """add_holding() on an existing position must update avg_cost correctly. + """add_holding() on existing position must update avg_cost correctly.""" + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + existing = _mock_holding(shares=50.0, avg_cost=190.0) + mock_supabase_client.get_holding.return_value = existing + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t - Formula: new_avg_cost = (old_shares * old_avg_cost + new_shares * price) - / (old_shares + new_shares) - """ - # TODO: implement - # existing holding: 50 shares @ 190.0 - # buy 25 more @ 200.0 - # expected avg_cost = (50*190 + 25*200) / 75 = 193.33... - raise NotImplementedError + repo = _make_repo(mock_supabase_client, report_store) + holding = repo.add_holding("pid-1", "AAPL", shares=25, price=200.0) + + # expected: (50*190 + 25*200) / 75 = 193.333... + expected_avg = (50 * 190.0 + 25 * 200.0) / 75 + assert holding.shares == 75 + assert holding.avg_cost == pytest.approx(expected_avg) # --------------------------------------------------------------------------- @@ -64,11 +106,11 @@ def test_add_holding_updates_avg_cost(mock_supabase_client, report_store): def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store): """add_holding() must raise InsufficientCashError when cash < shares * price.""" - # TODO: implement - # portfolio with cash=500.0, try to buy 10 shares @ 200.0 (cost=2000) - # with pytest.raises(InsufficientCashError): - # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=500.0) + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(InsufficientCashError): + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) # --------------------------------------------------------------------------- @@ -78,11 +120,17 @@ def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store def test_remove_holding_full_position(mock_supabase_client, report_store): """remove_holding() selling all shares must delete the holding row.""" - # TODO: implement - # holding: 50 shares - # sell 50 shares → holding deleted, cash credited - # assert mock_supabase_client.delete_holding.called - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.delete_holding.return_value = None + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + result = repo.remove_holding("pid-1", "AAPL", shares=50.0, price=200.0) + + assert result is None + assert mock_supabase_client.delete_holding.called # --------------------------------------------------------------------------- @@ -92,10 +140,17 @@ def test_remove_holding_full_position(mock_supabase_client, report_store): def test_remove_holding_partial_position(mock_supabase_client, report_store): """remove_holding() selling a subset must reduce shares, not delete.""" - # TODO: implement - # holding: 50 shares - # sell 20 → holding.shares == 30, avg_cost unchanged - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + result = repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + assert result is not None + assert result.shares == 30.0 # --------------------------------------------------------------------------- @@ -105,16 +160,20 @@ def test_remove_holding_partial_position(mock_supabase_client, report_store): def test_remove_holding_raises_insufficient_shares(mock_supabase_client, report_store): """remove_holding() must raise InsufficientSharesError when shares > held.""" - # TODO: implement - # holding: 10 shares - # try sell 20 → InsufficientSharesError - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=10.0) + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(InsufficientSharesError): + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report_store): """remove_holding() must raise HoldingNotFoundError for unknown tickers.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_holding.return_value = None + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(HoldingNotFoundError): + repo.remove_holding("pid-1", "ZZZZ", shares=10.0, price=100.0) # --------------------------------------------------------------------------- @@ -124,15 +183,35 @@ def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report def test_add_holding_deducts_cash(mock_supabase_client, report_store): """add_holding() must reduce portfolio.cash by shares * price.""" - # TODO: implement - # portfolio.cash = 10_000, buy 10 @ 200 → cash should be 8_000 - raise NotImplementedError + portfolio = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + # Check the portfolio passed to update_portfolio had cash deducted + updated = mock_supabase_client.update_portfolio.call_args[0][0] + assert updated.cash == pytest.approx(8_000.0) def test_remove_holding_credits_cash(mock_supabase_client, report_store): """remove_holding() must increase portfolio.cash by shares * price.""" - # TODO: implement - raise NotImplementedError + portfolio = _mock_portfolio(cash=5_000.0) + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + updated = mock_supabase_client.update_portfolio.call_args[0][0] + assert updated.cash == pytest.approx(9_000.0) # --------------------------------------------------------------------------- @@ -142,14 +221,37 @@ def test_remove_holding_credits_cash(mock_supabase_client, report_store): def test_add_holding_records_buy_trade(mock_supabase_client, report_store): """add_holding() must call client.record_trade() with action='BUY'.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + trade = mock_supabase_client.record_trade.call_args[0][0] + assert trade.action == "BUY" + assert trade.ticker == "AAPL" + assert trade.shares == 10 + assert trade.total_value == pytest.approx(2_000.0) def test_remove_holding_records_sell_trade(mock_supabase_client, report_store): """remove_holding() must call client.record_trade() with action='SELL'.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + trade = mock_supabase_client.record_trade.call_args[0][0] + assert trade.action == "SELL" + assert trade.ticker == "AAPL" + assert trade.shares == 20.0 # --------------------------------------------------------------------------- @@ -159,40 +261,120 @@ def test_remove_holding_records_sell_trade(mock_supabase_client, report_store): def test_take_snapshot(mock_supabase_client, report_store): """take_snapshot() must enrich holdings and persist a PortfolioSnapshot.""" - # TODO: implement - # assert mock_supabase_client.save_snapshot.called - # snapshot.total_value == cash + equity - raise NotImplementedError + portfolio = _mock_portfolio(cash=5_000.0) + holding = _mock_holding(shares=50.0, avg_cost=190.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.list_holdings.return_value = [holding] + mock_supabase_client.save_snapshot.side_effect = lambda s: s + + repo = _make_repo(mock_supabase_client, report_store) + snapshot = repo.take_snapshot("pid-1", prices={"AAPL": 200.0}) + + assert mock_supabase_client.save_snapshot.called + assert snapshot.cash == 5_000.0 + assert snapshot.num_positions == 1 + assert snapshot.total_value == pytest.approx(5_000.0 + 50.0 * 200.0) # --------------------------------------------------------------------------- -# Supabase integration tests (auto-skip without SUPABASE_URL) +# Supabase integration tests (auto-skip without SUPABASE_CONNECTION_STRING) # --------------------------------------------------------------------------- @requires_supabase def test_integration_create_and_get_portfolio(): """Integration: create a portfolio, retrieve it, verify fields match.""" - # TODO: implement - raise NotImplementedError + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Integration Test", initial_cash=50_000.0) + try: + fetched = repo.get_portfolio(portfolio.portfolio_id) + assert fetched.name == "Integration Test" + assert fetched.cash == pytest.approx(50_000.0) + assert fetched.initial_cash == pytest.approx(50_000.0) + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_add_and_remove_holding(): - """Integration: add holding, verify DB row; remove, verify deletion.""" - # TODO: implement - raise NotImplementedError + """Integration: add holding, verify; remove, verify deletion.""" + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Hold Test", initial_cash=50_000.0) + try: + holding = repo.add_holding( + portfolio.portfolio_id, "AAPL", shares=10, price=200.0, + sector="Technology", + ) + assert holding.ticker == "AAPL" + assert holding.shares == 10 + + # Verify cash deducted + p = repo.get_portfolio(portfolio.portfolio_id) + assert p.cash == pytest.approx(48_000.0) + + # Sell all + result = repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=10, price=210.0) + assert result is None + + # Verify cash credited + p = repo.get_portfolio(portfolio.portfolio_id) + assert p.cash == pytest.approx(50_100.0) + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_record_and_list_trades(): - """Integration: record BUY + SELL trades, list them, verify order.""" - # TODO: implement - raise NotImplementedError + """Integration: trades are recorded automatically via add/remove holding.""" + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Trade Test", initial_cash=50_000.0) + try: + repo.add_holding(portfolio.portfolio_id, "MSFT", shares=5, price=400.0) + repo.remove_holding(portfolio.portfolio_id, "MSFT", shares=5, price=410.0) + + trades = client.list_trades(portfolio.portfolio_id) + assert len(trades) == 2 + assert trades[0].action == "SELL" # newest first + assert trades[1].action == "BUY" + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_save_and_load_snapshot(): """Integration: take snapshot, retrieve latest, verify total_value.""" - # TODO: implement - raise NotImplementedError + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Snap Test", initial_cash=50_000.0) + try: + repo.add_holding(portfolio.portfolio_id, "AAPL", shares=10, price=200.0) + snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 210.0}) + + assert snapshot.num_positions == 1 + assert snapshot.cash == pytest.approx(48_000.0) + assert snapshot.total_value == pytest.approx(48_000.0 + 10 * 210.0) + + latest = client.get_latest_snapshot(portfolio.portfolio_id) + assert latest is not None + assert latest.snapshot_id == snapshot.snapshot_id + finally: + client.delete_portfolio(portfolio.portfolio_id) diff --git a/tradingagents/portfolio/config.py b/tradingagents/portfolio/config.py index 4ffbc403..d8293e49 100644 --- a/tradingagents/portfolio/config.py +++ b/tradingagents/portfolio/config.py @@ -1,8 +1,7 @@ """Portfolio Manager configuration. -Reads all portfolio-related settings from environment variables with sensible -defaults. Integrates with the existing ``tradingagents/default_config.py`` -pattern. +Integrates with the existing ``tradingagents/default_config.py`` pattern, +reading all portfolio settings from ``TRADINGAGENTS_`` env vars. Usage:: @@ -11,34 +10,51 @@ Usage:: cfg = get_portfolio_config() validate_config(cfg) print(cfg["max_positions"]) # 15 - -Environment variables (all optional): - - SUPABASE_URL Supabase project URL (default: "") - SUPABASE_KEY Supabase anon/service role key (default: "") - PORTFOLIO_DATA_DIR Root dir for filesystem reports (default: "reports") - PM_MAX_POSITIONS Max open positions (default: 15) - PM_MAX_POSITION_PCT Max single-position weight 0–1 (default: 0.15) - PM_MAX_SECTOR_PCT Max sector weight 0–1 (default: 0.35) - PM_MIN_CASH_PCT Minimum cash reserve 0–1 (default: 0.05) - PM_DEFAULT_BUDGET Default starting cash in USD (default: 100000.0) """ from __future__ import annotations import os -# --------------------------------------------------------------------------- -# Defaults -# --------------------------------------------------------------------------- +from dotenv import load_dotenv + +load_dotenv() + + +def _env(key: str, default=None): + """Read ``TRADINGAGENTS_`` from the environment. + + Matches the convention in ``tradingagents/default_config.py``. + """ + val = os.getenv(f"TRADINGAGENTS_{key.upper()}") + if not val: + return default + return val + + +def _env_float(key: str, default=None): + val = _env(key) + if val is None: + return default + try: + return float(val) + except (ValueError, TypeError): + return default + + +def _env_int(key: str, default=None): + val = _env(key) + if val is None: + return default + try: + return int(val) + except (ValueError, TypeError): + return default + PORTFOLIO_CONFIG: dict = { - # Supabase connection - "supabase_url": "", - "supabase_key": "", - # Filesystem report root (matches report_paths.py REPORTS_ROOT) - "data_dir": "reports", - # PM constraint defaults + "supabase_connection_string": os.getenv("SUPABASE_CONNECTION_STRING", ""), + "data_dir": _env("PORTFOLIO_DATA_DIR", "reports"), "max_positions": 15, "max_position_pct": 0.15, "max_sector_pct": 0.35, @@ -47,41 +63,44 @@ PORTFOLIO_CONFIG: dict = { } -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - def get_portfolio_config() -> dict: """Return the merged portfolio config (defaults overridden by env vars). - Reads ``SUPABASE_URL``, ``SUPABASE_KEY``, ``PORTFOLIO_DATA_DIR``, - ``PM_MAX_POSITIONS``, ``PM_MAX_POSITION_PCT``, ``PM_MAX_SECTOR_PCT``, - ``PM_MIN_CASH_PCT``, and ``PM_DEFAULT_BUDGET`` from the environment. - Returns: A dict with all portfolio configuration keys. """ - # TODO: implement — merge PORTFOLIO_CONFIG with env var overrides - raise NotImplementedError + cfg = dict(PORTFOLIO_CONFIG) + cfg["supabase_connection_string"] = os.getenv("SUPABASE_CONNECTION_STRING", cfg["supabase_connection_string"]) + cfg["data_dir"] = _env("PORTFOLIO_DATA_DIR", cfg["data_dir"]) + cfg["max_positions"] = _env_int("PM_MAX_POSITIONS", cfg["max_positions"]) + cfg["max_position_pct"] = _env_float("PM_MAX_POSITION_PCT", cfg["max_position_pct"]) + cfg["max_sector_pct"] = _env_float("PM_MAX_SECTOR_PCT", cfg["max_sector_pct"]) + cfg["min_cash_pct"] = _env_float("PM_MIN_CASH_PCT", cfg["min_cash_pct"]) + cfg["default_budget"] = _env_float("PM_DEFAULT_BUDGET", cfg["default_budget"]) + return cfg def validate_config(cfg: dict) -> None: """Validate a portfolio config dict, raising ValueError on invalid values. - Checks: - - ``max_positions >= 1`` - - ``0 < max_position_pct <= 1`` - - ``0 < max_sector_pct <= 1`` - - ``0 <= min_cash_pct < 1`` - - ``default_budget > 0`` - - ``min_cash_pct + max_position_pct <= 1`` (can always meet both constraints) - Args: cfg: Config dict as returned by ``get_portfolio_config()``. Raises: ValueError: With a descriptive message on the first failed check. """ - # TODO: implement - raise NotImplementedError + if cfg["max_positions"] < 1: + raise ValueError(f"max_positions must be >= 1, got {cfg['max_positions']}") + if not (0 < cfg["max_position_pct"] <= 1.0): + raise ValueError(f"max_position_pct must be in (0, 1], got {cfg['max_position_pct']}") + if not (0 < cfg["max_sector_pct"] <= 1.0): + raise ValueError(f"max_sector_pct must be in (0, 1], got {cfg['max_sector_pct']}") + if not (0 <= cfg["min_cash_pct"] < 1.0): + raise ValueError(f"min_cash_pct must be in [0, 1), got {cfg['min_cash_pct']}") + if cfg["default_budget"] <= 0: + raise ValueError(f"default_budget must be > 0, got {cfg['default_budget']}") + if cfg["min_cash_pct"] + cfg["max_position_pct"] > 1.0: + raise ValueError( + f"min_cash_pct ({cfg['min_cash_pct']}) + max_position_pct ({cfg['max_position_pct']}) " + f"must be <= 1.0" + ) diff --git a/tradingagents/portfolio/models.py b/tradingagents/portfolio/models.py index 0cbfd442..f9989af4 100644 --- a/tradingagents/portfolio/models.py +++ b/tradingagents/portfolio/models.py @@ -106,7 +106,7 @@ class Portfolio: h.current_value for h in holdings if h.current_value is not None ) self.total_value = self.cash + self.equity_value - self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0 + self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 1.0 return self diff --git a/tradingagents/portfolio/repository.py b/tradingagents/portfolio/repository.py index 123803fd..0a556ef5 100644 --- a/tradingagents/portfolio/repository.py +++ b/tradingagents/portfolio/repository.py @@ -1,38 +1,28 @@ -"""Unified data-access façade for the Portfolio Manager. +"""Unified data-access facade for the Portfolio Manager. ``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and ``ReportStore`` (filesystem documents) into a single, business-logic-aware interface. -Callers should **only** interact with ``PortfolioRepository`` — do not use -``SupabaseClient`` or ``ReportStore`` directly from outside this package. - Usage:: from tradingagents.portfolio import PortfolioRepository repo = PortfolioRepository() - - # Create a portfolio portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0) - - # Buy shares holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50) - # Sell shares - repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=25, price=200.00) - - # Snapshot - snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 200.00}) - See ``docs/portfolio/04_repository_api.md`` for full API documentation. """ from __future__ import annotations +import uuid +from datetime import datetime, timezone from pathlib import Path from typing import Any +from tradingagents.portfolio.config import get_portfolio_config from tradingagents.portfolio.exceptions import ( HoldingNotFoundError, InsufficientCashError, @@ -49,7 +39,7 @@ from tradingagents.portfolio.supabase_client import SupabaseClient class PortfolioRepository: - """Unified façade over SupabaseClient and ReportStore. + """Unified facade over SupabaseClient and ReportStore. Implements business logic for: - Average cost basis updates on repeated buys @@ -64,16 +54,9 @@ class PortfolioRepository: store: ReportStore | None = None, config: dict[str, Any] | None = None, ) -> None: - """Initialise the repository. - - Args: - client: SupabaseClient instance. Uses ``SupabaseClient.get_instance()`` - when None. - store: ReportStore instance. Creates a default instance when None. - config: Portfolio config dict. Uses ``get_portfolio_config()`` when None. - """ - # TODO: implement — resolve defaults, store as self._client, self._store, self._cfg - raise NotImplementedError + self._cfg = config or get_portfolio_config() + self._client = client or SupabaseClient.get_instance() + self._store = store or ReportStore(base_dir=self._cfg["data_dir"]) # ------------------------------------------------------------------ # Portfolio lifecycle @@ -85,54 +68,42 @@ class PortfolioRepository: initial_cash: float, currency: str = "USD", ) -> Portfolio: - """Create a new portfolio with the given starting capital. - - Generates a UUID for ``portfolio_id``. Sets ``cash = initial_cash``. - - Args: - name: Human-readable portfolio name. - initial_cash: Starting capital in USD (or configured currency). - currency: ISO 4217 currency code. - - Returns: - Persisted Portfolio instance. - - Raises: - DuplicatePortfolioError: If the name is already in use. - ValueError: If ``initial_cash <= 0``. - """ - # TODO: implement - raise NotImplementedError + """Create a new portfolio with the given starting capital.""" + if initial_cash <= 0: + raise ValueError(f"initial_cash must be > 0, got {initial_cash}") + portfolio = Portfolio( + portfolio_id=str(uuid.uuid4()), + name=name, + cash=initial_cash, + initial_cash=initial_cash, + currency=currency, + ) + return self._client.create_portfolio(portfolio) def get_portfolio(self, portfolio_id: str) -> Portfolio: - """Fetch a portfolio by ID. - - Args: - portfolio_id: UUID of the target portfolio. - - Raises: - PortfolioNotFoundError: If not found. - """ - # TODO: implement - raise NotImplementedError + """Fetch a portfolio by ID.""" + return self._client.get_portfolio(portfolio_id) def get_portfolio_with_holdings( self, portfolio_id: str, prices: dict[str, float] | None = None, ) -> tuple[Portfolio, list[Holding]]: - """Fetch portfolio + all holdings, optionally enriched with current prices. - - Args: - portfolio_id: UUID of the target portfolio. - prices: Optional ``{ticker: current_price}`` dict. When provided, - holdings are enriched and ``Portfolio.enrich()`` is called. - - Returns: - ``(Portfolio, list[Holding])`` — enriched when prices are supplied. - """ - # TODO: implement - raise NotImplementedError + """Fetch portfolio + all holdings, optionally enriched with current prices.""" + portfolio = self._client.get_portfolio(portfolio_id) + holdings = self._client.list_holdings(portfolio_id) + if prices: + # First pass: compute equity for total_value + equity = sum( + prices.get(h.ticker, 0.0) * h.shares for h in holdings + ) + total_value = portfolio.cash + equity + # Second pass: enrich each holding with weight + for h in holdings: + if h.ticker in prices: + h.enrich(prices[h.ticker], total_value) + portfolio.enrich(holdings) + return portfolio, holdings # ------------------------------------------------------------------ # Holdings management @@ -147,37 +118,65 @@ class PortfolioRepository: sector: str | None = None, industry: str | None = None, ) -> Holding: - """Buy shares and update portfolio cash and holdings. + """Buy shares and update portfolio cash and holdings.""" + if shares <= 0: + raise ValueError(f"shares must be > 0, got {shares}") + if price <= 0: + raise ValueError(f"price must be > 0, got {price}") - Business logic: - - Raises ``InsufficientCashError`` if ``portfolio.cash < shares * price`` - - If holding already exists: updates ``avg_cost`` using weighted average - - ``portfolio.cash -= shares * price`` - - Records a BUY trade automatically + cost = shares * price + portfolio = self._client.get_portfolio(portfolio_id) - Avg cost formula:: + if portfolio.cash < cost: + raise InsufficientCashError( + f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available" + ) - new_avg_cost = (old_shares * old_avg_cost + new_shares * price) - / (old_shares + new_shares) + # Check for existing holding to update avg cost + existing = self._client.get_holding(portfolio_id, ticker) + if existing: + new_total_shares = existing.shares + shares + new_avg_cost = ( + (existing.shares * existing.avg_cost + shares * price) / new_total_shares + ) + existing.shares = new_total_shares + existing.avg_cost = new_avg_cost + if sector: + existing.sector = sector + if industry: + existing.industry = industry + holding = self._client.upsert_holding(existing) + else: + holding = Holding( + holding_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + shares=shares, + avg_cost=price, + sector=sector, + industry=industry, + ) + holding = self._client.upsert_holding(holding) - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - shares: Number of shares to buy (must be > 0). - price: Execution price per share. - sector: Optional GICS sector name. - industry: Optional GICS industry name. + # Deduct cash + portfolio.cash -= cost + self._client.update_portfolio(portfolio) - Returns: - Updated or created Holding. + # Record trade + trade = Trade( + trade_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + action="BUY", + shares=shares, + price=price, + total_value=cost, + trade_date=datetime.now(timezone.utc).isoformat(), + signal_source="pm_agent", + ) + self._client.record_trade(trade) - Raises: - InsufficientCashError: If cash would go negative. - PortfolioNotFoundError: If portfolio_id does not exist. - ValueError: If shares <= 0 or price <= 0. - """ - # TODO: implement - raise NotImplementedError + return holding def remove_holding( self, @@ -186,33 +185,53 @@ class PortfolioRepository: shares: float, price: float, ) -> Holding | None: - """Sell shares and update portfolio cash and holdings. + """Sell shares and update portfolio cash and holdings.""" + if shares <= 0: + raise ValueError(f"shares must be > 0, got {shares}") + if price <= 0: + raise ValueError(f"price must be > 0, got {price}") - Business logic: - - Raises ``HoldingNotFoundError`` if no holding exists for ticker - - Raises ``InsufficientSharesError`` if ``holding.shares < shares`` - - If ``shares == holding.shares``: deletes the holding row, returns None - - Otherwise: decrements ``holding.shares`` (avg_cost unchanged on sell) - - ``portfolio.cash += shares * price`` - - Records a SELL trade automatically + existing = self._client.get_holding(portfolio_id, ticker) + if not existing: + raise HoldingNotFoundError( + f"No holding for {ticker} in portfolio {portfolio_id}" + ) - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - shares: Number of shares to sell (must be > 0). - price: Execution price per share. + if existing.shares < shares: + raise InsufficientSharesError( + f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}" + ) - Returns: - Updated Holding, or None if the position was fully closed. + proceeds = shares * price + portfolio = self._client.get_portfolio(portfolio_id) - Raises: - HoldingNotFoundError: If no holding exists for this ticker. - InsufficientSharesError: If holding.shares < shares. - PortfolioNotFoundError: If portfolio_id does not exist. - ValueError: If shares <= 0 or price <= 0. - """ - # TODO: implement - raise NotImplementedError + if existing.shares == shares: + # Full sell — delete holding + self._client.delete_holding(portfolio_id, ticker) + result = None + else: + existing.shares -= shares + result = self._client.upsert_holding(existing) + + # Credit cash + portfolio.cash += proceeds + self._client.update_portfolio(portfolio) + + # Record trade + trade = Trade( + trade_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + action="SELL", + shares=shares, + price=price, + total_value=proceeds, + trade_date=datetime.now(timezone.utc).isoformat(), + signal_source="pm_agent", + ) + self._client.record_trade(trade) + + return result # ------------------------------------------------------------------ # Snapshots @@ -223,20 +242,19 @@ class PortfolioRepository: portfolio_id: str, prices: dict[str, float], ) -> PortfolioSnapshot: - """Take an immutable snapshot of the current portfolio state. - - Fetches all holdings, enriches them with ``prices``, computes - ``total_value``, then persists via ``SupabaseClient.save_snapshot()``. - - Args: - portfolio_id: UUID of the target portfolio. - prices: ``{ticker: current_price}`` for all held tickers. - - Returns: - Persisted PortfolioSnapshot. - """ - # TODO: implement - raise NotImplementedError + """Take an immutable snapshot of the current portfolio state.""" + portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices) + snapshot = PortfolioSnapshot( + snapshot_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + snapshot_date=datetime.now(timezone.utc).isoformat(), + total_value=portfolio.total_value or portfolio.cash, + cash=portfolio.cash, + equity_value=portfolio.equity_value or 0.0, + num_positions=len(holdings), + holdings_snapshot=[h.to_dict() for h in holdings], + ) + return self._client.save_snapshot(snapshot) # ------------------------------------------------------------------ # Report convenience methods @@ -249,37 +267,21 @@ class PortfolioRepository: decision: dict[str, Any], markdown: str | None = None, ) -> Path: - """Save a PM agent decision and update portfolio.report_path. - - Delegates to ``ReportStore.save_pm_decision()`` then updates the - ``portfolio.report_path`` column in Supabase to point to the daily - portfolio directory. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string, e.g. ``"2026-03-20"``. - decision: PM decision dict. - markdown: Optional human-readable markdown version. - - Returns: - Path of the written JSON file. - """ - # TODO: implement - raise NotImplementedError + """Save a PM agent decision and update portfolio.report_path.""" + path = self._store.save_pm_decision(date, portfolio_id, decision, markdown) + # Update portfolio report_path + portfolio = self._client.get_portfolio(portfolio_id) + portfolio.report_path = str(self._store._portfolio_dir(date)) + self._client.update_portfolio(portfolio) + return path def load_pm_decision( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: - """Load a PM decision JSON. Returns None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - """ - # TODO: implement - raise NotImplementedError + """Load a PM decision JSON. Returns None if not found.""" + return self._store.load_pm_decision(date, portfolio_id) def save_risk_metrics( self, @@ -287,29 +289,13 @@ class PortfolioRepository: date: str, metrics: dict[str, Any], ) -> Path: - """Save risk computation results. Delegates to ReportStore. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - metrics: Risk metrics dict (Sharpe, Sortino, VaR, beta, etc.). - - Returns: - Path of the written file. - """ - # TODO: implement - raise NotImplementedError + """Save risk computation results.""" + return self._store.save_risk_metrics(date, portfolio_id, metrics) def load_risk_metrics( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: - """Load risk metrics. Returns None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - """ - # TODO: implement - raise NotImplementedError + """Load risk metrics. Returns None if not found.""" + return self._store.load_risk_metrics(date, portfolio_id) diff --git a/tradingagents/portfolio/supabase_client.py b/tradingagents/portfolio/supabase_client.py index efd87b34..c0107544 100644 --- a/tradingagents/portfolio/supabase_client.py +++ b/tradingagents/portfolio/supabase_client.py @@ -1,34 +1,30 @@ -"""Supabase database client for the Portfolio Manager. +"""PostgreSQL database client for the Portfolio Manager. -Thin wrapper around ``supabase-py`` (no ORM) that: -- Provides a singleton connection (one client per process) -- Translates Supabase / HTTP errors into domain exceptions -- Converts raw DB rows into typed model instances via ``Model.from_dict()`` - -**No ORM is used here by design** — see -``docs/agent/decisions/012-portfolio-no-orm.md`` for the full rationale. -In short: ``supabase-py``'s builder-pattern API is sufficient for 4 tables; -Prisma and SQLAlchemy add build-step / runtime complexity that isn't warranted -for this non-DB-heavy feature. +Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk +directly to the Supabase-hosted PostgreSQL database. No ORM — see +``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale. Usage:: from tradingagents.portfolio.supabase_client import SupabaseClient - from tradingagents.portfolio.models import Portfolio client = SupabaseClient.get_instance() portfolio = client.get_portfolio("some-uuid") - -Configuration (read from environment via ``get_portfolio_config()``): - SUPABASE_URL — Supabase project URL - SUPABASE_KEY — Supabase anon or service-role key """ from __future__ import annotations +import json +import uuid + +import psycopg2 +import psycopg2.extras + +from tradingagents.portfolio.config import get_portfolio_config from tradingagents.portfolio.exceptions import ( DuplicatePortfolioError, HoldingNotFoundError, + PortfolioError, PortfolioNotFoundError, ) from tradingagents.portfolio.models import ( @@ -40,167 +36,213 @@ from tradingagents.portfolio.models import ( class SupabaseClient: - """Singleton Supabase CRUD client for portfolio data. + """Singleton PostgreSQL CRUD client for portfolio data. - All public methods translate Supabase / HTTP errors into domain exceptions + All public methods translate database errors into domain exceptions and return typed model instances. - - Do not instantiate directly — use ``SupabaseClient.get_instance()``. """ - _instance: "SupabaseClient | None" = None + _instance: SupabaseClient | None = None - def __init__(self, url: str, key: str) -> None: - """Initialise the Supabase client. + def __init__(self, connection_string: str) -> None: + self._dsn = self._fix_dsn(connection_string) + self._conn = psycopg2.connect(self._dsn) + self._conn.autocommit = True - Args: - url: Supabase project URL. - key: Supabase anon or service-role key. - """ - # TODO: implement — create supabase.create_client(url, key) - raise NotImplementedError + @staticmethod + def _fix_dsn(dsn: str) -> str: + """URL-encode the password if it contains special characters.""" + from urllib.parse import quote + if "://" not in dsn: + return dsn # already key=value format + scheme, rest = dsn.split("://", 1) + at_idx = rest.rfind("@") + if at_idx == -1: + return dsn + userinfo = rest[:at_idx] + hostinfo = rest[at_idx + 1:] + colon_idx = userinfo.find(":") + if colon_idx == -1: + return dsn + user = userinfo[:colon_idx] + password = userinfo[colon_idx + 1:] + encoded = quote(password, safe="") + return f"{scheme}://{user}:{encoded}@{hostinfo}" @classmethod - def get_instance(cls) -> "SupabaseClient": - """Return the singleton instance, creating it if necessary. + def get_instance(cls) -> SupabaseClient: + """Return the singleton instance, creating it if necessary.""" + if cls._instance is None: + cfg = get_portfolio_config() + dsn = cfg["supabase_connection_string"] + if not dsn: + raise PortfolioError( + "SUPABASE_CONNECTION_STRING not configured. " + "Set it in .env or as an environment variable." + ) + cls._instance = cls(dsn) + return cls._instance - Reads SUPABASE_URL and SUPABASE_KEY from ``get_portfolio_config()``. + @classmethod + def reset_instance(cls) -> None: + """Close and reset the singleton (for testing).""" + if cls._instance is not None: + try: + cls._instance._conn.close() + except Exception: + pass + cls._instance = None - Raises: - PortfolioError: If SUPABASE_URL or SUPABASE_KEY are not configured. - """ - # TODO: implement - raise NotImplementedError + def _cursor(self): + """Return a RealDictCursor.""" + return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) # ------------------------------------------------------------------ # Portfolio CRUD # ------------------------------------------------------------------ def create_portfolio(self, portfolio: Portfolio) -> Portfolio: - """Insert a new portfolio row. - - Args: - portfolio: Portfolio instance with all required fields set. - - Returns: - Portfolio with DB-assigned timestamps. - - Raises: - DuplicatePortfolioError: If portfolio_id already exists. - """ - # TODO: implement - raise NotImplementedError + """Insert a new portfolio row.""" + pid = portfolio.portfolio_id or str(uuid.uuid4()) + try: + with self._cursor() as cur: + cur.execute( + """INSERT INTO portfolios + (portfolio_id, name, cash, initial_cash, currency, report_path, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (pid, portfolio.name, portfolio.cash, portfolio.initial_cash, + portfolio.currency, portfolio.report_path, + json.dumps(portfolio.metadata)), + ) + row = cur.fetchone() + except psycopg2.errors.UniqueViolation as exc: + raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc + return self._row_to_portfolio(row) def get_portfolio(self, portfolio_id: str) -> Portfolio: - """Fetch a portfolio by ID. - - Args: - portfolio_id: UUID of the target portfolio. - - Returns: - Portfolio instance. - - Raises: - PortfolioNotFoundError: If no portfolio has that ID. - """ - # TODO: implement - raise NotImplementedError + """Fetch a portfolio by ID.""" + with self._cursor() as cur: + cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,)) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}") + return self._row_to_portfolio(row) def list_portfolios(self) -> list[Portfolio]: """Return all portfolios ordered by created_at DESC.""" - # TODO: implement - raise NotImplementedError + with self._cursor() as cur: + cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC") + rows = cur.fetchall() + return [self._row_to_portfolio(r) for r in rows] def update_portfolio(self, portfolio: Portfolio) -> Portfolio: - """Update mutable portfolio fields (cash, report_path, metadata). - - Args: - portfolio: Portfolio with updated field values. - - Returns: - Updated Portfolio with refreshed updated_at. - - Raises: - PortfolioNotFoundError: If portfolio_id does not exist. - """ - # TODO: implement - raise NotImplementedError + """Update mutable portfolio fields (cash, report_path, metadata).""" + with self._cursor() as cur: + cur.execute( + """UPDATE portfolios + SET cash = %s, report_path = %s, metadata = %s + WHERE portfolio_id = %s + RETURNING *""", + (portfolio.cash, portfolio.report_path, + json.dumps(portfolio.metadata), portfolio.portfolio_id), + ) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}") + return self._row_to_portfolio(row) def delete_portfolio(self, portfolio_id: str) -> None: - """Delete a portfolio and all associated data (CASCADE). - - Args: - portfolio_id: UUID of the portfolio to delete. - - Raises: - PortfolioNotFoundError: If portfolio_id does not exist. - """ - # TODO: implement - raise NotImplementedError + """Delete a portfolio and all associated data (CASCADE).""" + with self._cursor() as cur: + cur.execute( + "DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id", + (portfolio_id,), + ) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}") # ------------------------------------------------------------------ # Holdings CRUD # ------------------------------------------------------------------ def upsert_holding(self, holding: Holding) -> Holding: - """Insert or update a holding row (upsert on portfolio_id + ticker). - - Args: - holding: Holding instance with all required fields set. - - Returns: - Holding with DB-assigned / refreshed timestamps. - """ - # TODO: implement - raise NotImplementedError + """Insert or update a holding row (upsert on portfolio_id + ticker).""" + hid = holding.holding_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO holdings + (holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique + DO UPDATE SET shares = EXCLUDED.shares, + avg_cost = EXCLUDED.avg_cost, + sector = EXCLUDED.sector, + industry = EXCLUDED.industry + RETURNING *""", + (hid, holding.portfolio_id, holding.ticker.upper(), + holding.shares, holding.avg_cost, holding.sector, holding.industry), + ) + row = cur.fetchone() + return self._row_to_holding(row) def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None: - """Return the holding for (portfolio_id, ticker), or None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol (case-insensitive, stored as uppercase). - """ - # TODO: implement - raise NotImplementedError + """Return the holding for (portfolio_id, ticker), or None.""" + with self._cursor() as cur: + cur.execute( + "SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s", + (portfolio_id, ticker.upper()), + ) + row = cur.fetchone() + return self._row_to_holding(row) if row else None def list_holdings(self, portfolio_id: str) -> list[Holding]: - """Return all holdings for a portfolio ordered by cost_basis DESC. - - Args: - portfolio_id: UUID of the target portfolio. - """ - # TODO: implement - raise NotImplementedError + """Return all holdings for a portfolio ordered by cost_basis DESC.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM holdings + WHERE portfolio_id = %s + ORDER BY shares * avg_cost DESC""", + (portfolio_id,), + ) + rows = cur.fetchall() + return [self._row_to_holding(r) for r in rows] def delete_holding(self, portfolio_id: str, ticker: str) -> None: - """Delete the holding for (portfolio_id, ticker). - - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - - Raises: - HoldingNotFoundError: If no such holding exists. - """ - # TODO: implement - raise NotImplementedError + """Delete the holding for (portfolio_id, ticker).""" + with self._cursor() as cur: + cur.execute( + "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id", + (portfolio_id, ticker.upper()), + ) + row = cur.fetchone() + if not row: + raise HoldingNotFoundError( + f"Holding not found: {ticker} in portfolio {portfolio_id}" + ) # ------------------------------------------------------------------ # Trades # ------------------------------------------------------------------ def record_trade(self, trade: Trade) -> Trade: - """Insert a new trade record. Immutable — no update method. - - Args: - trade: Trade instance with all required fields set. - - Returns: - Trade with DB-assigned trade_id and trade_date. - """ - # TODO: implement - raise NotImplementedError + """Insert a new trade record.""" + tid = trade.trade_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO trades + (trade_id, portfolio_id, ticker, action, shares, price, + total_value, rationale, signal_source, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (tid, trade.portfolio_id, trade.ticker, trade.action, + trade.shares, trade.price, trade.total_value, + trade.rationale, trade.signal_source, + json.dumps(trade.metadata)), + ) + row = cur.fetchone() + return self._row_to_trade(row) def list_trades( self, @@ -208,51 +250,142 @@ class SupabaseClient: ticker: str | None = None, limit: int = 100, ) -> list[Trade]: - """Return recent trades for a portfolio, newest first. - - Args: - portfolio_id: Filter by portfolio. - ticker: Optional additional filter by ticker symbol. - limit: Maximum number of rows to return. - """ - # TODO: implement - raise NotImplementedError + """Return recent trades for a portfolio, newest first.""" + if ticker: + query = """SELECT * FROM trades + WHERE portfolio_id = %s AND ticker = %s + ORDER BY trade_date DESC LIMIT %s""" + params = (portfolio_id, ticker.upper(), limit) + else: + query = """SELECT * FROM trades + WHERE portfolio_id = %s + ORDER BY trade_date DESC LIMIT %s""" + params = (portfolio_id, limit) + with self._cursor() as cur: + cur.execute(query, params) + rows = cur.fetchall() + return [self._row_to_trade(r) for r in rows] # ------------------------------------------------------------------ # Snapshots # ------------------------------------------------------------------ def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot: - """Insert a new immutable portfolio snapshot. - - Args: - snapshot: PortfolioSnapshot with all required fields set. - - Returns: - Snapshot with DB-assigned snapshot_id and snapshot_date. - """ - # TODO: implement - raise NotImplementedError + """Insert a new immutable portfolio snapshot.""" + sid = snapshot.snapshot_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO snapshots + (snapshot_id, portfolio_id, total_value, cash, equity_value, + num_positions, holdings_snapshot, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (sid, snapshot.portfolio_id, snapshot.total_value, + snapshot.cash, snapshot.equity_value, snapshot.num_positions, + json.dumps(snapshot.holdings_snapshot), + json.dumps(snapshot.metadata)), + ) + row = cur.fetchone() + return self._row_to_snapshot(row) def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None: - """Return the most recent snapshot for a portfolio, or None. - - Args: - portfolio_id: UUID of the target portfolio. - """ - # TODO: implement - raise NotImplementedError + """Return the most recent snapshot for a portfolio, or None.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM snapshots + WHERE portfolio_id = %s + ORDER BY snapshot_date DESC LIMIT 1""", + (portfolio_id,), + ) + row = cur.fetchone() + return self._row_to_snapshot(row) if row else None def list_snapshots( self, portfolio_id: str, limit: int = 30, ) -> list[PortfolioSnapshot]: - """Return snapshots newest-first up to limit. + """Return snapshots newest-first up to limit.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM snapshots + WHERE portfolio_id = %s + ORDER BY snapshot_date DESC LIMIT %s""", + (portfolio_id, limit), + ) + rows = cur.fetchall() + return [self._row_to_snapshot(r) for r in rows] - Args: - portfolio_id: UUID of the target portfolio. - limit: Maximum number of snapshots to return. - """ - # TODO: implement - raise NotImplementedError + # ------------------------------------------------------------------ + # Row -> Model helpers + # ------------------------------------------------------------------ + + @staticmethod + def _row_to_portfolio(row: dict) -> Portfolio: + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return Portfolio( + portfolio_id=str(row["portfolio_id"]), + name=row["name"], + cash=float(row["cash"]), + initial_cash=float(row["initial_cash"]), + currency=row["currency"].strip(), + created_at=str(row["created_at"]), + updated_at=str(row["updated_at"]), + report_path=row.get("report_path"), + metadata=metadata, + ) + + @staticmethod + def _row_to_holding(row: dict) -> Holding: + return Holding( + holding_id=str(row["holding_id"]), + portfolio_id=str(row["portfolio_id"]), + ticker=row["ticker"], + shares=float(row["shares"]), + avg_cost=float(row["avg_cost"]), + sector=row.get("sector"), + industry=row.get("industry"), + created_at=str(row["created_at"]), + updated_at=str(row["updated_at"]), + ) + + @staticmethod + def _row_to_trade(row: dict) -> Trade: + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return Trade( + trade_id=str(row["trade_id"]), + portfolio_id=str(row["portfolio_id"]), + ticker=row["ticker"], + action=row["action"], + shares=float(row["shares"]), + price=float(row["price"]), + total_value=float(row["total_value"]), + trade_date=str(row["trade_date"]), + rationale=row.get("rationale"), + signal_source=row.get("signal_source"), + metadata=metadata, + ) + + @staticmethod + def _row_to_snapshot(row: dict) -> PortfolioSnapshot: + holdings = row.get("holdings_snapshot") or [] + if isinstance(holdings, str): + holdings = json.loads(holdings) + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return PortfolioSnapshot( + snapshot_id=str(row["snapshot_id"]), + portfolio_id=str(row["portfolio_id"]), + snapshot_date=str(row["snapshot_date"]), + total_value=float(row["total_value"]), + cash=float(row["cash"]), + equity_value=float(row["equity_value"]), + num_positions=int(row["num_positions"]), + holdings_snapshot=holdings, + metadata=metadata, + )