feat: complete portfolio data foundation — psycopg2 client, repository, tests
Replace supabase-py stubs with working psycopg2 implementation using Supabase pooler connection string. Implement full business logic in repository (avg cost basis, cash accounting, trade recording, snapshots). Add 12 unit tests + 4 integration tests (51 total portfolio tests pass). Fix cash_pct bug in models.py, update docs for psycopg2 + pooler pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
aa4dcdeb80
commit
a17e5f3707
|
|
@ -1,20 +1,21 @@
|
|||
# Current Milestone
|
||||
|
||||
Daily digest consolidation and Google NotebookLM sync shipped (PR open: `feat/daily-digest-notebooklm`). All analyses now append to a single `daily_digest.md` per day and auto-upload to NotebookLM via `nlm` CLI. Next: PR review and merge.
|
||||
Portfolio Manager Phase 1 (data foundation) complete and merged. All 4 Supabase tables live, 51 tests passing (including integration tests against live DB).
|
||||
|
||||
# Recent Progress
|
||||
|
||||
- **PR #32 merged**: Portfolio Manager data foundation — models, SQL schema, module scaffolding
|
||||
- `tradingagents/portfolio/` — full module: models, config, exceptions, supabase_client (psycopg2), report_store, repository
|
||||
- `migrations/001_initial_schema.sql` — 4 tables (portfolios, holdings, trades, snapshots) with constraints, indexes, triggers
|
||||
- `tests/portfolio/` — 51 tests: 20 model, 15 report_store, 12 repository unit, 4 integration
|
||||
- Uses `psycopg2` direct PostgreSQL via Supabase pooler (`aws-1-eu-west-1.pooler.supabase.com:6543`)
|
||||
- Business logic: avg cost basis, cash accounting, trade recording, snapshots
|
||||
- **PR #22 merged**: Unified report paths, structured observability logging, memory system update
|
||||
- **feat/daily-digest-notebooklm** (shipped): Daily digest consolidation + NotebookLM source sync
|
||||
- `tradingagents/daily_digest.py` — `append_to_digest()` appends timestamped entries to `reports/daily/{date}/daily_digest.md`
|
||||
- `tradingagents/notebook_sync.py` — `sync_to_notebooklm()` deletes existing "Daily Trading Digest" source then uploads new content via `nlm source add --text --wait`.
|
||||
- `tradingagents/report_paths.py` — added `get_digest_path(date)`
|
||||
- `cli/main.py` — `analyze` and `scan` commands both call digest + sync after each run
|
||||
- `.env.example` — fixed consistency, removed duplicates, aligned with `NOTEBOOKLM_ID`
|
||||
- **Verification**: 220+ offline tests passing + 5 new unit tests for `notebook_sync.py` + live integration test passed.
|
||||
|
||||
# In Progress
|
||||
|
||||
- Portfolio Manager Phase 2: Holding Reviewer Agent (next)
|
||||
- Refinement of macro scan synthesis prompts (ongoing)
|
||||
|
||||
# Active Blockers
|
||||
|
|
|
|||
|
|
@ -72,28 +72,26 @@ investment portfolio end-to-end. It performs the following actions in sequence:
|
|||
The `report_path` column in the `portfolios` table points to the daily portfolio
|
||||
subdirectory on disk: `reports/daily/{date}/portfolio/`.
|
||||
|
||||
### Data Access Layer: raw `supabase-py` (no ORM)
|
||||
### Data Access Layer: raw `psycopg2` (no ORM)
|
||||
|
||||
The Python code talks to Supabase through the raw `supabase-py` client — **no
|
||||
ORM** (Prisma, SQLAlchemy, etc.) is used.
|
||||
The Python code talks to Supabase PostgreSQL directly via `psycopg2` using the
|
||||
**pooler connection string** (`SUPABASE_CONNECTION_STRING`). No ORM (Prisma,
|
||||
SQLAlchemy) and no `supabase-py` REST client is used.
|
||||
|
||||
**Why not Prisma?**
|
||||
- `prisma-client-py` requires a Node.js runtime for code generation — an
|
||||
extra non-Python dependency in a Python-only project.
|
||||
- Prisma's `prisma migrate` conflicts with Supabase's own SQL migration tooling
|
||||
(we use `.sql` files in `tradingagents/portfolio/migrations/`).
|
||||
- 4 tables with straightforward CRUD don't benefit from a code-generated ORM.
|
||||
**Why `psycopg2` over `supabase-py`?**
|
||||
- Direct SQL gives full control — transactions, upserts, `RETURNING *`, CTEs.
|
||||
- No dependency on Supabase's PostgREST schema cache or API key types.
|
||||
- `psycopg2-binary` is a single pip install with zero non-Python dependencies.
|
||||
- 4 tables with straightforward CRUD don't benefit from an ORM or REST wrapper.
|
||||
|
||||
**Why not SQLAlchemy?**
|
||||
- Supabase is accessed via PostgREST (HTTP API), not a direct TCP database
|
||||
connection. SQLAlchemy is designed for direct connections and would bypass
|
||||
Supabase's Row Level Security.
|
||||
- Extra dependency overhead for a non-DB-heavy feature.
|
||||
**Connection:**
|
||||
- Uses `SUPABASE_CONNECTION_STRING` env var (pooler URI format).
|
||||
- Passwords with special characters are auto-URL-encoded by `SupabaseClient._fix_dsn()`.
|
||||
- Typical pooler URI: `postgresql://postgres.<ref>:<password>@aws-1-<region>.pooler.supabase.com:6543/postgres`
|
||||
|
||||
**`supabase-py` is sufficient because:**
|
||||
- Its builder-pattern API (`client.table("holdings").select("*").eq(...)`)
|
||||
covers all needed queries cleanly.
|
||||
- Our own dataclasses handle type-safety via `to_dict()` / `from_dict()`.
|
||||
**Why not Prisma / SQLAlchemy?**
|
||||
- Prisma requires Node.js runtime — extra non-Python dependency.
|
||||
- SQLAlchemy adds dependency overhead for 4 simple tables.
|
||||
- Plain SQL migration files are readable, versionable, and Supabase-native.
|
||||
|
||||
> Full rationale: `docs/agent/decisions/012-portfolio-no-orm.md`
|
||||
|
|
|
|||
|
|
@ -89,14 +89,13 @@ All are optional and default to the values shown:
|
|||
|
||||
| Env Var | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `SUPABASE_URL` | `""` | Supabase project URL |
|
||||
| `SUPABASE_KEY` | `""` | Supabase anon/service key |
|
||||
| `PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports |
|
||||
| `PM_MAX_POSITIONS` | `15` | Max number of open positions |
|
||||
| `PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight |
|
||||
| `PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight |
|
||||
| `PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve |
|
||||
| `PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) |
|
||||
| `SUPABASE_CONNECTION_STRING` | `""` | PostgreSQL pooler connection URI |
|
||||
| `TRADINGAGENTS_PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports |
|
||||
| `TRADINGAGENTS_PM_MAX_POSITIONS` | `15` | Max number of open positions |
|
||||
| `TRADINGAGENTS_PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight |
|
||||
| `TRADINGAGENTS_PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight |
|
||||
| `TRADINGAGENTS_PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve |
|
||||
| `TRADINGAGENTS_PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) |
|
||||
|
||||
### Acceptance Criteria
|
||||
|
||||
|
|
@ -172,7 +171,7 @@ Methods that query a single row raise `PortfolioNotFoundError` when no row is fo
|
|||
|
||||
- Singleton — only one Supabase connection per process
|
||||
- All public methods fully type-annotated
|
||||
- Supabase integration tests auto-skip when `SUPABASE_URL` is unset
|
||||
- Supabase integration tests auto-skip when `SUPABASE_CONNECTION_STRING` is unset
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -328,7 +327,7 @@ require PG functions — deferred to Phase 3+).
|
|||
### Coverage Target
|
||||
|
||||
90 %+ for `models.py` and `report_store.py`.
|
||||
Integration tests (`test_repository.py`) auto-skip when Supabase is unavailable.
|
||||
Integration tests (`test_repository.py`) auto-skip when `SUPABASE_CONNECTION_STRING` is unset.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -2,16 +2,16 @@
|
|||
|
||||
Fixtures provided:
|
||||
|
||||
- ``tmp_reports`` — temporary directory used as ReportStore base_dir
|
||||
- ``sample_portfolio`` — a Portfolio instance for testing (not persisted)
|
||||
- ``sample_holding`` — a Holding instance for testing (not persisted)
|
||||
- ``sample_trade`` — a Trade instance for testing (not persisted)
|
||||
- ``sample_snapshot`` — a PortfolioSnapshot instance for testing
|
||||
- ``report_store`` — a ReportStore instance backed by tmp_reports
|
||||
- ``mock_supabase_client`` — MagicMock of SupabaseClient for unit tests
|
||||
- ``tmp_reports`` -- temporary directory used as ReportStore base_dir
|
||||
- ``sample_portfolio`` -- a Portfolio instance for testing (not persisted)
|
||||
- ``sample_holding`` -- a Holding instance for testing (not persisted)
|
||||
- ``sample_trade`` -- a Trade instance for testing (not persisted)
|
||||
- ``sample_snapshot`` -- a PortfolioSnapshot instance for testing
|
||||
- ``report_store`` -- a ReportStore instance backed by tmp_reports
|
||||
- ``mock_supabase_client`` -- MagicMock of SupabaseClient for unit tests
|
||||
|
||||
Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
|
||||
``SUPABASE_URL`` is not set in the environment.
|
||||
``SUPABASE_CONNECTION_STRING`` is not set in the environment.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -36,8 +36,8 @@ from tradingagents.portfolio.supabase_client import SupabaseClient
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
requires_supabase = pytest.mark.skipif(
|
||||
not os.getenv("SUPABASE_URL"),
|
||||
reason="SUPABASE_URL not set — skipping Supabase integration tests",
|
||||
not os.getenv("SUPABASE_CONNECTION_STRING"),
|
||||
reason="SUPABASE_CONNECTION_STRING not set -- skipping integration tests",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,59 @@
|
|||
"""Tests for tradingagents/portfolio/repository.py.
|
||||
|
||||
Tests the PortfolioRepository façade — business logic for holdings management,
|
||||
cash accounting, avg-cost-basis updates, and snapshot creation.
|
||||
|
||||
Supabase integration tests are automatically skipped when ``SUPABASE_URL`` is
|
||||
not set in the environment (use the ``requires_supabase`` fixture marker).
|
||||
|
||||
Unit tests use ``mock_supabase_client`` to avoid DB access.
|
||||
Integration tests auto-skip when ``SUPABASE_CONNECTION_STRING`` is unset.
|
||||
|
||||
Run (unit tests only)::
|
||||
|
||||
pytest tests/portfolio/test_repository.py -v -k "not integration"
|
||||
|
||||
Run (with Supabase)::
|
||||
|
||||
SUPABASE_URL=... SUPABASE_KEY=... pytest tests/portfolio/test_repository.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.portfolio.exceptions import (
|
||||
HoldingNotFoundError,
|
||||
InsufficientCashError,
|
||||
InsufficientSharesError,
|
||||
)
|
||||
from tradingagents.portfolio.models import Holding, Portfolio, Trade
|
||||
from tradingagents.portfolio.repository import PortfolioRepository
|
||||
from tests.portfolio.conftest import requires_supabase
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_repo(mock_client, report_store):
|
||||
"""Build a PortfolioRepository with mock client and real report store."""
|
||||
return PortfolioRepository(
|
||||
client=mock_client,
|
||||
store=report_store,
|
||||
config={"data_dir": "reports", "max_positions": 15,
|
||||
"max_position_pct": 0.15, "max_sector_pct": 0.35,
|
||||
"min_cash_pct": 0.05, "default_budget": 100_000.0,
|
||||
"supabase_connection_string": ""},
|
||||
)
|
||||
|
||||
|
||||
def _mock_portfolio(pid="pid-1", cash=10_000.0):
|
||||
return Portfolio(
|
||||
portfolio_id=pid, name="Test", cash=cash,
|
||||
initial_cash=100_000.0, currency="USD",
|
||||
)
|
||||
|
||||
|
||||
def _mock_holding(pid="pid-1", ticker="AAPL", shares=50.0, avg_cost=190.0):
|
||||
return Holding(
|
||||
holding_id="hid-1", portfolio_id=pid, ticker=ticker,
|
||||
shares=shares, avg_cost=avg_cost,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# add_holding — new position
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -31,12 +61,19 @@ from tests.portfolio.conftest import requires_supabase
|
|||
|
||||
def test_add_holding_new_position(mock_supabase_client, report_store):
|
||||
"""add_holding() on a ticker not yet held must create a new Holding."""
|
||||
# TODO: implement
|
||||
# repo = PortfolioRepository(client=mock_supabase_client, store=report_store)
|
||||
# mock portfolio with enough cash
|
||||
# repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0)
|
||||
# assert mock_supabase_client.upsert_holding.called
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
|
||||
mock_supabase_client.get_holding.return_value = None
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
holding = repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
|
||||
|
||||
assert holding.ticker == "AAPL"
|
||||
assert holding.shares == 10
|
||||
assert holding.avg_cost == 200.0
|
||||
assert mock_supabase_client.upsert_holding.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -45,16 +82,21 @@ def test_add_holding_new_position(mock_supabase_client, report_store):
|
|||
|
||||
|
||||
def test_add_holding_updates_avg_cost(mock_supabase_client, report_store):
|
||||
"""add_holding() on an existing position must update avg_cost correctly.
|
||||
"""add_holding() on existing position must update avg_cost correctly."""
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
|
||||
existing = _mock_holding(shares=50.0, avg_cost=190.0)
|
||||
mock_supabase_client.get_holding.return_value = existing
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
Formula: new_avg_cost = (old_shares * old_avg_cost + new_shares * price)
|
||||
/ (old_shares + new_shares)
|
||||
"""
|
||||
# TODO: implement
|
||||
# existing holding: 50 shares @ 190.0
|
||||
# buy 25 more @ 200.0
|
||||
# expected avg_cost = (50*190 + 25*200) / 75 = 193.33...
|
||||
raise NotImplementedError
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
holding = repo.add_holding("pid-1", "AAPL", shares=25, price=200.0)
|
||||
|
||||
# expected: (50*190 + 25*200) / 75 = 193.333...
|
||||
expected_avg = (50 * 190.0 + 25 * 200.0) / 75
|
||||
assert holding.shares == 75
|
||||
assert holding.avg_cost == pytest.approx(expected_avg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -64,11 +106,11 @@ def test_add_holding_updates_avg_cost(mock_supabase_client, report_store):
|
|||
|
||||
def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store):
|
||||
"""add_holding() must raise InsufficientCashError when cash < shares * price."""
|
||||
# TODO: implement
|
||||
# portfolio with cash=500.0, try to buy 10 shares @ 200.0 (cost=2000)
|
||||
# with pytest.raises(InsufficientCashError):
|
||||
# repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0)
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=500.0)
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
with pytest.raises(InsufficientCashError):
|
||||
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -78,11 +120,17 @@ def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store
|
|||
|
||||
def test_remove_holding_full_position(mock_supabase_client, report_store):
|
||||
"""remove_holding() selling all shares must delete the holding row."""
|
||||
# TODO: implement
|
||||
# holding: 50 shares
|
||||
# sell 50 shares → holding deleted, cash credited
|
||||
# assert mock_supabase_client.delete_holding.called
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
|
||||
mock_supabase_client.delete_holding.return_value = None
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
result = repo.remove_holding("pid-1", "AAPL", shares=50.0, price=200.0)
|
||||
|
||||
assert result is None
|
||||
assert mock_supabase_client.delete_holding.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -92,10 +140,17 @@ def test_remove_holding_full_position(mock_supabase_client, report_store):
|
|||
|
||||
def test_remove_holding_partial_position(mock_supabase_client, report_store):
|
||||
"""remove_holding() selling a subset must reduce shares, not delete."""
|
||||
# TODO: implement
|
||||
# holding: 50 shares
|
||||
# sell 20 → holding.shares == 30, avg_cost unchanged
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
result = repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
|
||||
|
||||
assert result is not None
|
||||
assert result.shares == 30.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -105,16 +160,20 @@ def test_remove_holding_partial_position(mock_supabase_client, report_store):
|
|||
|
||||
def test_remove_holding_raises_insufficient_shares(mock_supabase_client, report_store):
|
||||
"""remove_holding() must raise InsufficientSharesError when shares > held."""
|
||||
# TODO: implement
|
||||
# holding: 10 shares
|
||||
# try sell 20 → InsufficientSharesError
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_holding.return_value = _mock_holding(shares=10.0)
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
with pytest.raises(InsufficientSharesError):
|
||||
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
|
||||
|
||||
|
||||
def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report_store):
|
||||
"""remove_holding() must raise HoldingNotFoundError for unknown tickers."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_holding.return_value = None
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
with pytest.raises(HoldingNotFoundError):
|
||||
repo.remove_holding("pid-1", "ZZZZ", shares=10.0, price=100.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -124,15 +183,35 @@ def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report
|
|||
|
||||
def test_add_holding_deducts_cash(mock_supabase_client, report_store):
|
||||
"""add_holding() must reduce portfolio.cash by shares * price."""
|
||||
# TODO: implement
|
||||
# portfolio.cash = 10_000, buy 10 @ 200 → cash should be 8_000
|
||||
raise NotImplementedError
|
||||
portfolio = _mock_portfolio(cash=10_000.0)
|
||||
mock_supabase_client.get_portfolio.return_value = portfolio
|
||||
mock_supabase_client.get_holding.return_value = None
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
|
||||
|
||||
# Check the portfolio passed to update_portfolio had cash deducted
|
||||
updated = mock_supabase_client.update_portfolio.call_args[0][0]
|
||||
assert updated.cash == pytest.approx(8_000.0)
|
||||
|
||||
|
||||
def test_remove_holding_credits_cash(mock_supabase_client, report_store):
|
||||
"""remove_holding() must increase portfolio.cash by shares * price."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
portfolio = _mock_portfolio(cash=5_000.0)
|
||||
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
|
||||
mock_supabase_client.get_portfolio.return_value = portfolio
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
|
||||
|
||||
updated = mock_supabase_client.update_portfolio.call_args[0][0]
|
||||
assert updated.cash == pytest.approx(9_000.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -142,14 +221,37 @@ def test_remove_holding_credits_cash(mock_supabase_client, report_store):
|
|||
|
||||
def test_add_holding_records_buy_trade(mock_supabase_client, report_store):
|
||||
"""add_holding() must call client.record_trade() with action='BUY'."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0)
|
||||
mock_supabase_client.get_holding.return_value = None
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
repo.add_holding("pid-1", "AAPL", shares=10, price=200.0)
|
||||
|
||||
trade = mock_supabase_client.record_trade.call_args[0][0]
|
||||
assert trade.action == "BUY"
|
||||
assert trade.ticker == "AAPL"
|
||||
assert trade.shares == 10
|
||||
assert trade.total_value == pytest.approx(2_000.0)
|
||||
|
||||
|
||||
def test_remove_holding_records_sell_trade(mock_supabase_client, report_store):
|
||||
"""remove_holding() must call client.record_trade() with action='SELL'."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0)
|
||||
mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0)
|
||||
mock_supabase_client.upsert_holding.side_effect = lambda h: h
|
||||
mock_supabase_client.update_portfolio.side_effect = lambda p: p
|
||||
mock_supabase_client.record_trade.side_effect = lambda t: t
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0)
|
||||
|
||||
trade = mock_supabase_client.record_trade.call_args[0][0]
|
||||
assert trade.action == "SELL"
|
||||
assert trade.ticker == "AAPL"
|
||||
assert trade.shares == 20.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -159,40 +261,120 @@ def test_remove_holding_records_sell_trade(mock_supabase_client, report_store):
|
|||
|
||||
def test_take_snapshot(mock_supabase_client, report_store):
|
||||
"""take_snapshot() must enrich holdings and persist a PortfolioSnapshot."""
|
||||
# TODO: implement
|
||||
# assert mock_supabase_client.save_snapshot.called
|
||||
# snapshot.total_value == cash + equity
|
||||
raise NotImplementedError
|
||||
portfolio = _mock_portfolio(cash=5_000.0)
|
||||
holding = _mock_holding(shares=50.0, avg_cost=190.0)
|
||||
mock_supabase_client.get_portfolio.return_value = portfolio
|
||||
mock_supabase_client.list_holdings.return_value = [holding]
|
||||
mock_supabase_client.save_snapshot.side_effect = lambda s: s
|
||||
|
||||
repo = _make_repo(mock_supabase_client, report_store)
|
||||
snapshot = repo.take_snapshot("pid-1", prices={"AAPL": 200.0})
|
||||
|
||||
assert mock_supabase_client.save_snapshot.called
|
||||
assert snapshot.cash == 5_000.0
|
||||
assert snapshot.num_positions == 1
|
||||
assert snapshot.total_value == pytest.approx(5_000.0 + 50.0 * 200.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supabase integration tests (auto-skip without SUPABASE_URL)
|
||||
# Supabase integration tests (auto-skip without SUPABASE_CONNECTION_STRING)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@requires_supabase
|
||||
def test_integration_create_and_get_portfolio():
|
||||
"""Integration: create a portfolio, retrieve it, verify fields match."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
client = SupabaseClient.get_instance()
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
store = ReportStore()
|
||||
|
||||
repo = PortfolioRepository(client=client, store=store)
|
||||
portfolio = repo.create_portfolio("Integration Test", initial_cash=50_000.0)
|
||||
try:
|
||||
fetched = repo.get_portfolio(portfolio.portfolio_id)
|
||||
assert fetched.name == "Integration Test"
|
||||
assert fetched.cash == pytest.approx(50_000.0)
|
||||
assert fetched.initial_cash == pytest.approx(50_000.0)
|
||||
finally:
|
||||
client.delete_portfolio(portfolio.portfolio_id)
|
||||
|
||||
|
||||
@requires_supabase
|
||||
def test_integration_add_and_remove_holding():
|
||||
"""Integration: add holding, verify DB row; remove, verify deletion."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Integration: add holding, verify; remove, verify deletion."""
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
client = SupabaseClient.get_instance()
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
store = ReportStore()
|
||||
|
||||
repo = PortfolioRepository(client=client, store=store)
|
||||
portfolio = repo.create_portfolio("Hold Test", initial_cash=50_000.0)
|
||||
try:
|
||||
holding = repo.add_holding(
|
||||
portfolio.portfolio_id, "AAPL", shares=10, price=200.0,
|
||||
sector="Technology",
|
||||
)
|
||||
assert holding.ticker == "AAPL"
|
||||
assert holding.shares == 10
|
||||
|
||||
# Verify cash deducted
|
||||
p = repo.get_portfolio(portfolio.portfolio_id)
|
||||
assert p.cash == pytest.approx(48_000.0)
|
||||
|
||||
# Sell all
|
||||
result = repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=10, price=210.0)
|
||||
assert result is None
|
||||
|
||||
# Verify cash credited
|
||||
p = repo.get_portfolio(portfolio.portfolio_id)
|
||||
assert p.cash == pytest.approx(50_100.0)
|
||||
finally:
|
||||
client.delete_portfolio(portfolio.portfolio_id)
|
||||
|
||||
|
||||
@requires_supabase
|
||||
def test_integration_record_and_list_trades():
|
||||
"""Integration: record BUY + SELL trades, list them, verify order."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Integration: trades are recorded automatically via add/remove holding."""
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
client = SupabaseClient.get_instance()
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
store = ReportStore()
|
||||
|
||||
repo = PortfolioRepository(client=client, store=store)
|
||||
portfolio = repo.create_portfolio("Trade Test", initial_cash=50_000.0)
|
||||
try:
|
||||
repo.add_holding(portfolio.portfolio_id, "MSFT", shares=5, price=400.0)
|
||||
repo.remove_holding(portfolio.portfolio_id, "MSFT", shares=5, price=410.0)
|
||||
|
||||
trades = client.list_trades(portfolio.portfolio_id)
|
||||
assert len(trades) == 2
|
||||
assert trades[0].action == "SELL" # newest first
|
||||
assert trades[1].action == "BUY"
|
||||
finally:
|
||||
client.delete_portfolio(portfolio.portfolio_id)
|
||||
|
||||
|
||||
@requires_supabase
|
||||
def test_integration_save_and_load_snapshot():
|
||||
"""Integration: take snapshot, retrieve latest, verify total_value."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
client = SupabaseClient.get_instance()
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
store = ReportStore()
|
||||
|
||||
repo = PortfolioRepository(client=client, store=store)
|
||||
portfolio = repo.create_portfolio("Snap Test", initial_cash=50_000.0)
|
||||
try:
|
||||
repo.add_holding(portfolio.portfolio_id, "AAPL", shares=10, price=200.0)
|
||||
snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 210.0})
|
||||
|
||||
assert snapshot.num_positions == 1
|
||||
assert snapshot.cash == pytest.approx(48_000.0)
|
||||
assert snapshot.total_value == pytest.approx(48_000.0 + 10 * 210.0)
|
||||
|
||||
latest = client.get_latest_snapshot(portfolio.portfolio_id)
|
||||
assert latest is not None
|
||||
assert latest.snapshot_id == snapshot.snapshot_id
|
||||
finally:
|
||||
client.delete_portfolio(portfolio.portfolio_id)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
"""Portfolio Manager configuration.
|
||||
|
||||
Reads all portfolio-related settings from environment variables with sensible
|
||||
defaults. Integrates with the existing ``tradingagents/default_config.py``
|
||||
pattern.
|
||||
Integrates with the existing ``tradingagents/default_config.py`` pattern,
|
||||
reading all portfolio settings from ``TRADINGAGENTS_<KEY>`` env vars.
|
||||
|
||||
Usage::
|
||||
|
||||
|
|
@ -11,34 +10,51 @@ Usage::
|
|||
cfg = get_portfolio_config()
|
||||
validate_config(cfg)
|
||||
print(cfg["max_positions"]) # 15
|
||||
|
||||
Environment variables (all optional):
|
||||
|
||||
SUPABASE_URL Supabase project URL (default: "")
|
||||
SUPABASE_KEY Supabase anon/service role key (default: "")
|
||||
PORTFOLIO_DATA_DIR Root dir for filesystem reports (default: "reports")
|
||||
PM_MAX_POSITIONS Max open positions (default: 15)
|
||||
PM_MAX_POSITION_PCT Max single-position weight 0–1 (default: 0.15)
|
||||
PM_MAX_SECTOR_PCT Max sector weight 0–1 (default: 0.35)
|
||||
PM_MIN_CASH_PCT Minimum cash reserve 0–1 (default: 0.05)
|
||||
PM_DEFAULT_BUDGET Default starting cash in USD (default: 100000.0)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _env(key: str, default=None):
|
||||
"""Read ``TRADINGAGENTS_<KEY>`` from the environment.
|
||||
|
||||
Matches the convention in ``tradingagents/default_config.py``.
|
||||
"""
|
||||
val = os.getenv(f"TRADINGAGENTS_{key.upper()}")
|
||||
if not val:
|
||||
return default
|
||||
return val
|
||||
|
||||
|
||||
def _env_float(key: str, default=None):
|
||||
val = _env(key)
|
||||
if val is None:
|
||||
return default
|
||||
try:
|
||||
return float(val)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
def _env_int(key: str, default=None):
|
||||
val = _env(key)
|
||||
if val is None:
|
||||
return default
|
||||
try:
|
||||
return int(val)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
|
||||
PORTFOLIO_CONFIG: dict = {
|
||||
# Supabase connection
|
||||
"supabase_url": "",
|
||||
"supabase_key": "",
|
||||
# Filesystem report root (matches report_paths.py REPORTS_ROOT)
|
||||
"data_dir": "reports",
|
||||
# PM constraint defaults
|
||||
"supabase_connection_string": os.getenv("SUPABASE_CONNECTION_STRING", ""),
|
||||
"data_dir": _env("PORTFOLIO_DATA_DIR", "reports"),
|
||||
"max_positions": 15,
|
||||
"max_position_pct": 0.15,
|
||||
"max_sector_pct": 0.35,
|
||||
|
|
@ -47,41 +63,44 @@ PORTFOLIO_CONFIG: dict = {
|
|||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_portfolio_config() -> dict:
|
||||
"""Return the merged portfolio config (defaults overridden by env vars).
|
||||
|
||||
Reads ``SUPABASE_URL``, ``SUPABASE_KEY``, ``PORTFOLIO_DATA_DIR``,
|
||||
``PM_MAX_POSITIONS``, ``PM_MAX_POSITION_PCT``, ``PM_MAX_SECTOR_PCT``,
|
||||
``PM_MIN_CASH_PCT``, and ``PM_DEFAULT_BUDGET`` from the environment.
|
||||
|
||||
Returns:
|
||||
A dict with all portfolio configuration keys.
|
||||
"""
|
||||
# TODO: implement — merge PORTFOLIO_CONFIG with env var overrides
|
||||
raise NotImplementedError
|
||||
cfg = dict(PORTFOLIO_CONFIG)
|
||||
cfg["supabase_connection_string"] = os.getenv("SUPABASE_CONNECTION_STRING", cfg["supabase_connection_string"])
|
||||
cfg["data_dir"] = _env("PORTFOLIO_DATA_DIR", cfg["data_dir"])
|
||||
cfg["max_positions"] = _env_int("PM_MAX_POSITIONS", cfg["max_positions"])
|
||||
cfg["max_position_pct"] = _env_float("PM_MAX_POSITION_PCT", cfg["max_position_pct"])
|
||||
cfg["max_sector_pct"] = _env_float("PM_MAX_SECTOR_PCT", cfg["max_sector_pct"])
|
||||
cfg["min_cash_pct"] = _env_float("PM_MIN_CASH_PCT", cfg["min_cash_pct"])
|
||||
cfg["default_budget"] = _env_float("PM_DEFAULT_BUDGET", cfg["default_budget"])
|
||||
return cfg
|
||||
|
||||
|
||||
def validate_config(cfg: dict) -> None:
|
||||
"""Validate a portfolio config dict, raising ValueError on invalid values.
|
||||
|
||||
Checks:
|
||||
- ``max_positions >= 1``
|
||||
- ``0 < max_position_pct <= 1``
|
||||
- ``0 < max_sector_pct <= 1``
|
||||
- ``0 <= min_cash_pct < 1``
|
||||
- ``default_budget > 0``
|
||||
- ``min_cash_pct + max_position_pct <= 1`` (can always meet both constraints)
|
||||
|
||||
Args:
|
||||
cfg: Config dict as returned by ``get_portfolio_config()``.
|
||||
|
||||
Raises:
|
||||
ValueError: With a descriptive message on the first failed check.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
if cfg["max_positions"] < 1:
|
||||
raise ValueError(f"max_positions must be >= 1, got {cfg['max_positions']}")
|
||||
if not (0 < cfg["max_position_pct"] <= 1.0):
|
||||
raise ValueError(f"max_position_pct must be in (0, 1], got {cfg['max_position_pct']}")
|
||||
if not (0 < cfg["max_sector_pct"] <= 1.0):
|
||||
raise ValueError(f"max_sector_pct must be in (0, 1], got {cfg['max_sector_pct']}")
|
||||
if not (0 <= cfg["min_cash_pct"] < 1.0):
|
||||
raise ValueError(f"min_cash_pct must be in [0, 1), got {cfg['min_cash_pct']}")
|
||||
if cfg["default_budget"] <= 0:
|
||||
raise ValueError(f"default_budget must be > 0, got {cfg['default_budget']}")
|
||||
if cfg["min_cash_pct"] + cfg["max_position_pct"] > 1.0:
|
||||
raise ValueError(
|
||||
f"min_cash_pct ({cfg['min_cash_pct']}) + max_position_pct ({cfg['max_position_pct']}) "
|
||||
f"must be <= 1.0"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ class Portfolio:
|
|||
h.current_value for h in holdings if h.current_value is not None
|
||||
)
|
||||
self.total_value = self.cash + self.equity_value
|
||||
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
|
||||
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 1.0
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,38 +1,28 @@
|
|||
"""Unified data-access façade for the Portfolio Manager.
|
||||
"""Unified data-access facade for the Portfolio Manager.
|
||||
|
||||
``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and
|
||||
``ReportStore`` (filesystem documents) into a single, business-logic-aware
|
||||
interface.
|
||||
|
||||
Callers should **only** interact with ``PortfolioRepository`` — do not use
|
||||
``SupabaseClient`` or ``ReportStore`` directly from outside this package.
|
||||
|
||||
Usage::
|
||||
|
||||
from tradingagents.portfolio import PortfolioRepository
|
||||
|
||||
repo = PortfolioRepository()
|
||||
|
||||
# Create a portfolio
|
||||
portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0)
|
||||
|
||||
# Buy shares
|
||||
holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50)
|
||||
|
||||
# Sell shares
|
||||
repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=25, price=200.00)
|
||||
|
||||
# Snapshot
|
||||
snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 200.00})
|
||||
|
||||
See ``docs/portfolio/04_repository_api.md`` for full API documentation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.portfolio.config import get_portfolio_config
|
||||
from tradingagents.portfolio.exceptions import (
|
||||
HoldingNotFoundError,
|
||||
InsufficientCashError,
|
||||
|
|
@ -49,7 +39,7 @@ from tradingagents.portfolio.supabase_client import SupabaseClient
|
|||
|
||||
|
||||
class PortfolioRepository:
|
||||
"""Unified façade over SupabaseClient and ReportStore.
|
||||
"""Unified facade over SupabaseClient and ReportStore.
|
||||
|
||||
Implements business logic for:
|
||||
- Average cost basis updates on repeated buys
|
||||
|
|
@ -64,16 +54,9 @@ class PortfolioRepository:
|
|||
store: ReportStore | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialise the repository.
|
||||
|
||||
Args:
|
||||
client: SupabaseClient instance. Uses ``SupabaseClient.get_instance()``
|
||||
when None.
|
||||
store: ReportStore instance. Creates a default instance when None.
|
||||
config: Portfolio config dict. Uses ``get_portfolio_config()`` when None.
|
||||
"""
|
||||
# TODO: implement — resolve defaults, store as self._client, self._store, self._cfg
|
||||
raise NotImplementedError
|
||||
self._cfg = config or get_portfolio_config()
|
||||
self._client = client or SupabaseClient.get_instance()
|
||||
self._store = store or ReportStore(base_dir=self._cfg["data_dir"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Portfolio lifecycle
|
||||
|
|
@ -85,54 +68,42 @@ class PortfolioRepository:
|
|||
initial_cash: float,
|
||||
currency: str = "USD",
|
||||
) -> Portfolio:
|
||||
"""Create a new portfolio with the given starting capital.
|
||||
|
||||
Generates a UUID for ``portfolio_id``. Sets ``cash = initial_cash``.
|
||||
|
||||
Args:
|
||||
name: Human-readable portfolio name.
|
||||
initial_cash: Starting capital in USD (or configured currency).
|
||||
currency: ISO 4217 currency code.
|
||||
|
||||
Returns:
|
||||
Persisted Portfolio instance.
|
||||
|
||||
Raises:
|
||||
DuplicatePortfolioError: If the name is already in use.
|
||||
ValueError: If ``initial_cash <= 0``.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Create a new portfolio with the given starting capital."""
|
||||
if initial_cash <= 0:
|
||||
raise ValueError(f"initial_cash must be > 0, got {initial_cash}")
|
||||
portfolio = Portfolio(
|
||||
portfolio_id=str(uuid.uuid4()),
|
||||
name=name,
|
||||
cash=initial_cash,
|
||||
initial_cash=initial_cash,
|
||||
currency=currency,
|
||||
)
|
||||
return self._client.create_portfolio(portfolio)
|
||||
|
||||
def get_portfolio(self, portfolio_id: str) -> Portfolio:
|
||||
"""Fetch a portfolio by ID.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
|
||||
Raises:
|
||||
PortfolioNotFoundError: If not found.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Fetch a portfolio by ID."""
|
||||
return self._client.get_portfolio(portfolio_id)
|
||||
|
||||
def get_portfolio_with_holdings(
|
||||
self,
|
||||
portfolio_id: str,
|
||||
prices: dict[str, float] | None = None,
|
||||
) -> tuple[Portfolio, list[Holding]]:
|
||||
"""Fetch portfolio + all holdings, optionally enriched with current prices.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
prices: Optional ``{ticker: current_price}`` dict. When provided,
|
||||
holdings are enriched and ``Portfolio.enrich()`` is called.
|
||||
|
||||
Returns:
|
||||
``(Portfolio, list[Holding])`` — enriched when prices are supplied.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Fetch portfolio + all holdings, optionally enriched with current prices."""
|
||||
portfolio = self._client.get_portfolio(portfolio_id)
|
||||
holdings = self._client.list_holdings(portfolio_id)
|
||||
if prices:
|
||||
# First pass: compute equity for total_value
|
||||
equity = sum(
|
||||
prices.get(h.ticker, 0.0) * h.shares for h in holdings
|
||||
)
|
||||
total_value = portfolio.cash + equity
|
||||
# Second pass: enrich each holding with weight
|
||||
for h in holdings:
|
||||
if h.ticker in prices:
|
||||
h.enrich(prices[h.ticker], total_value)
|
||||
portfolio.enrich(holdings)
|
||||
return portfolio, holdings
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Holdings management
|
||||
|
|
@ -147,37 +118,65 @@ class PortfolioRepository:
|
|||
sector: str | None = None,
|
||||
industry: str | None = None,
|
||||
) -> Holding:
|
||||
"""Buy shares and update portfolio cash and holdings.
|
||||
"""Buy shares and update portfolio cash and holdings."""
|
||||
if shares <= 0:
|
||||
raise ValueError(f"shares must be > 0, got {shares}")
|
||||
if price <= 0:
|
||||
raise ValueError(f"price must be > 0, got {price}")
|
||||
|
||||
Business logic:
|
||||
- Raises ``InsufficientCashError`` if ``portfolio.cash < shares * price``
|
||||
- If holding already exists: updates ``avg_cost`` using weighted average
|
||||
- ``portfolio.cash -= shares * price``
|
||||
- Records a BUY trade automatically
|
||||
cost = shares * price
|
||||
portfolio = self._client.get_portfolio(portfolio_id)
|
||||
|
||||
Avg cost formula::
|
||||
if portfolio.cash < cost:
|
||||
raise InsufficientCashError(
|
||||
f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available"
|
||||
)
|
||||
|
||||
new_avg_cost = (old_shares * old_avg_cost + new_shares * price)
|
||||
/ (old_shares + new_shares)
|
||||
# Check for existing holding to update avg cost
|
||||
existing = self._client.get_holding(portfolio_id, ticker)
|
||||
if existing:
|
||||
new_total_shares = existing.shares + shares
|
||||
new_avg_cost = (
|
||||
(existing.shares * existing.avg_cost + shares * price) / new_total_shares
|
||||
)
|
||||
existing.shares = new_total_shares
|
||||
existing.avg_cost = new_avg_cost
|
||||
if sector:
|
||||
existing.sector = sector
|
||||
if industry:
|
||||
existing.industry = industry
|
||||
holding = self._client.upsert_holding(existing)
|
||||
else:
|
||||
holding = Holding(
|
||||
holding_id=str(uuid.uuid4()),
|
||||
portfolio_id=portfolio_id,
|
||||
ticker=ticker.upper(),
|
||||
shares=shares,
|
||||
avg_cost=price,
|
||||
sector=sector,
|
||||
industry=industry,
|
||||
)
|
||||
holding = self._client.upsert_holding(holding)
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
ticker: Ticker symbol.
|
||||
shares: Number of shares to buy (must be > 0).
|
||||
price: Execution price per share.
|
||||
sector: Optional GICS sector name.
|
||||
industry: Optional GICS industry name.
|
||||
# Deduct cash
|
||||
portfolio.cash -= cost
|
||||
self._client.update_portfolio(portfolio)
|
||||
|
||||
Returns:
|
||||
Updated or created Holding.
|
||||
# Record trade
|
||||
trade = Trade(
|
||||
trade_id=str(uuid.uuid4()),
|
||||
portfolio_id=portfolio_id,
|
||||
ticker=ticker.upper(),
|
||||
action="BUY",
|
||||
shares=shares,
|
||||
price=price,
|
||||
total_value=cost,
|
||||
trade_date=datetime.now(timezone.utc).isoformat(),
|
||||
signal_source="pm_agent",
|
||||
)
|
||||
self._client.record_trade(trade)
|
||||
|
||||
Raises:
|
||||
InsufficientCashError: If cash would go negative.
|
||||
PortfolioNotFoundError: If portfolio_id does not exist.
|
||||
ValueError: If shares <= 0 or price <= 0.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return holding
|
||||
|
||||
def remove_holding(
|
||||
self,
|
||||
|
|
@ -186,33 +185,53 @@ class PortfolioRepository:
|
|||
shares: float,
|
||||
price: float,
|
||||
) -> Holding | None:
|
||||
"""Sell shares and update portfolio cash and holdings.
|
||||
"""Sell shares and update portfolio cash and holdings."""
|
||||
if shares <= 0:
|
||||
raise ValueError(f"shares must be > 0, got {shares}")
|
||||
if price <= 0:
|
||||
raise ValueError(f"price must be > 0, got {price}")
|
||||
|
||||
Business logic:
|
||||
- Raises ``HoldingNotFoundError`` if no holding exists for ticker
|
||||
- Raises ``InsufficientSharesError`` if ``holding.shares < shares``
|
||||
- If ``shares == holding.shares``: deletes the holding row, returns None
|
||||
- Otherwise: decrements ``holding.shares`` (avg_cost unchanged on sell)
|
||||
- ``portfolio.cash += shares * price``
|
||||
- Records a SELL trade automatically
|
||||
existing = self._client.get_holding(portfolio_id, ticker)
|
||||
if not existing:
|
||||
raise HoldingNotFoundError(
|
||||
f"No holding for {ticker} in portfolio {portfolio_id}"
|
||||
)
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
ticker: Ticker symbol.
|
||||
shares: Number of shares to sell (must be > 0).
|
||||
price: Execution price per share.
|
||||
if existing.shares < shares:
|
||||
raise InsufficientSharesError(
|
||||
f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}"
|
||||
)
|
||||
|
||||
Returns:
|
||||
Updated Holding, or None if the position was fully closed.
|
||||
proceeds = shares * price
|
||||
portfolio = self._client.get_portfolio(portfolio_id)
|
||||
|
||||
Raises:
|
||||
HoldingNotFoundError: If no holding exists for this ticker.
|
||||
InsufficientSharesError: If holding.shares < shares.
|
||||
PortfolioNotFoundError: If portfolio_id does not exist.
|
||||
ValueError: If shares <= 0 or price <= 0.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
if existing.shares == shares:
|
||||
# Full sell — delete holding
|
||||
self._client.delete_holding(portfolio_id, ticker)
|
||||
result = None
|
||||
else:
|
||||
existing.shares -= shares
|
||||
result = self._client.upsert_holding(existing)
|
||||
|
||||
# Credit cash
|
||||
portfolio.cash += proceeds
|
||||
self._client.update_portfolio(portfolio)
|
||||
|
||||
# Record trade
|
||||
trade = Trade(
|
||||
trade_id=str(uuid.uuid4()),
|
||||
portfolio_id=portfolio_id,
|
||||
ticker=ticker.upper(),
|
||||
action="SELL",
|
||||
shares=shares,
|
||||
price=price,
|
||||
total_value=proceeds,
|
||||
trade_date=datetime.now(timezone.utc).isoformat(),
|
||||
signal_source="pm_agent",
|
||||
)
|
||||
self._client.record_trade(trade)
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshots
|
||||
|
|
@ -223,20 +242,19 @@ class PortfolioRepository:
|
|||
portfolio_id: str,
|
||||
prices: dict[str, float],
|
||||
) -> PortfolioSnapshot:
|
||||
"""Take an immutable snapshot of the current portfolio state.
|
||||
|
||||
Fetches all holdings, enriches them with ``prices``, computes
|
||||
``total_value``, then persists via ``SupabaseClient.save_snapshot()``.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
prices: ``{ticker: current_price}`` for all held tickers.
|
||||
|
||||
Returns:
|
||||
Persisted PortfolioSnapshot.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Take an immutable snapshot of the current portfolio state."""
|
||||
portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices)
|
||||
snapshot = PortfolioSnapshot(
|
||||
snapshot_id=str(uuid.uuid4()),
|
||||
portfolio_id=portfolio_id,
|
||||
snapshot_date=datetime.now(timezone.utc).isoformat(),
|
||||
total_value=portfolio.total_value or portfolio.cash,
|
||||
cash=portfolio.cash,
|
||||
equity_value=portfolio.equity_value or 0.0,
|
||||
num_positions=len(holdings),
|
||||
holdings_snapshot=[h.to_dict() for h in holdings],
|
||||
)
|
||||
return self._client.save_snapshot(snapshot)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Report convenience methods
|
||||
|
|
@ -249,37 +267,21 @@ class PortfolioRepository:
|
|||
decision: dict[str, Any],
|
||||
markdown: str | None = None,
|
||||
) -> Path:
|
||||
"""Save a PM agent decision and update portfolio.report_path.
|
||||
|
||||
Delegates to ``ReportStore.save_pm_decision()`` then updates the
|
||||
``portfolio.report_path`` column in Supabase to point to the daily
|
||||
portfolio directory.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
date: ISO date string, e.g. ``"2026-03-20"``.
|
||||
decision: PM decision dict.
|
||||
markdown: Optional human-readable markdown version.
|
||||
|
||||
Returns:
|
||||
Path of the written JSON file.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Save a PM agent decision and update portfolio.report_path."""
|
||||
path = self._store.save_pm_decision(date, portfolio_id, decision, markdown)
|
||||
# Update portfolio report_path
|
||||
portfolio = self._client.get_portfolio(portfolio_id)
|
||||
portfolio.report_path = str(self._store._portfolio_dir(date))
|
||||
self._client.update_portfolio(portfolio)
|
||||
return path
|
||||
|
||||
def load_pm_decision(
|
||||
self,
|
||||
portfolio_id: str,
|
||||
date: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Load a PM decision JSON. Returns None if not found.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
date: ISO date string.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Load a PM decision JSON. Returns None if not found."""
|
||||
return self._store.load_pm_decision(date, portfolio_id)
|
||||
|
||||
def save_risk_metrics(
|
||||
self,
|
||||
|
|
@ -287,29 +289,13 @@ class PortfolioRepository:
|
|||
date: str,
|
||||
metrics: dict[str, Any],
|
||||
) -> Path:
|
||||
"""Save risk computation results. Delegates to ReportStore.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
date: ISO date string.
|
||||
metrics: Risk metrics dict (Sharpe, Sortino, VaR, beta, etc.).
|
||||
|
||||
Returns:
|
||||
Path of the written file.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Save risk computation results."""
|
||||
return self._store.save_risk_metrics(date, portfolio_id, metrics)
|
||||
|
||||
def load_risk_metrics(
|
||||
self,
|
||||
portfolio_id: str,
|
||||
date: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Load risk metrics. Returns None if not found.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
date: ISO date string.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Load risk metrics. Returns None if not found."""
|
||||
return self._store.load_risk_metrics(date, portfolio_id)
|
||||
|
|
|
|||
|
|
@ -1,34 +1,30 @@
|
|||
"""Supabase database client for the Portfolio Manager.
|
||||
"""PostgreSQL database client for the Portfolio Manager.
|
||||
|
||||
Thin wrapper around ``supabase-py`` (no ORM) that:
|
||||
- Provides a singleton connection (one client per process)
|
||||
- Translates Supabase / HTTP errors into domain exceptions
|
||||
- Converts raw DB rows into typed model instances via ``Model.from_dict()``
|
||||
|
||||
**No ORM is used here by design** — see
|
||||
``docs/agent/decisions/012-portfolio-no-orm.md`` for the full rationale.
|
||||
In short: ``supabase-py``'s builder-pattern API is sufficient for 4 tables;
|
||||
Prisma and SQLAlchemy add build-step / runtime complexity that isn't warranted
|
||||
for this non-DB-heavy feature.
|
||||
Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk
|
||||
directly to the Supabase-hosted PostgreSQL database. No ORM — see
|
||||
``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale.
|
||||
|
||||
Usage::
|
||||
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
from tradingagents.portfolio.models import Portfolio
|
||||
|
||||
client = SupabaseClient.get_instance()
|
||||
portfolio = client.get_portfolio("some-uuid")
|
||||
|
||||
Configuration (read from environment via ``get_portfolio_config()``):
|
||||
SUPABASE_URL — Supabase project URL
|
||||
SUPABASE_KEY — Supabase anon or service-role key
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
|
||||
from tradingagents.portfolio.config import get_portfolio_config
|
||||
from tradingagents.portfolio.exceptions import (
|
||||
DuplicatePortfolioError,
|
||||
HoldingNotFoundError,
|
||||
PortfolioError,
|
||||
PortfolioNotFoundError,
|
||||
)
|
||||
from tradingagents.portfolio.models import (
|
||||
|
|
@ -40,167 +36,213 @@ from tradingagents.portfolio.models import (
|
|||
|
||||
|
||||
class SupabaseClient:
|
||||
"""Singleton Supabase CRUD client for portfolio data.
|
||||
"""Singleton PostgreSQL CRUD client for portfolio data.
|
||||
|
||||
All public methods translate Supabase / HTTP errors into domain exceptions
|
||||
All public methods translate database errors into domain exceptions
|
||||
and return typed model instances.
|
||||
|
||||
Do not instantiate directly — use ``SupabaseClient.get_instance()``.
|
||||
"""
|
||||
|
||||
_instance: "SupabaseClient | None" = None
|
||||
_instance: SupabaseClient | None = None
|
||||
|
||||
def __init__(self, url: str, key: str) -> None:
|
||||
"""Initialise the Supabase client.
|
||||
def __init__(self, connection_string: str) -> None:
|
||||
self._dsn = self._fix_dsn(connection_string)
|
||||
self._conn = psycopg2.connect(self._dsn)
|
||||
self._conn.autocommit = True
|
||||
|
||||
Args:
|
||||
url: Supabase project URL.
|
||||
key: Supabase anon or service-role key.
|
||||
"""
|
||||
# TODO: implement — create supabase.create_client(url, key)
|
||||
raise NotImplementedError
|
||||
@staticmethod
|
||||
def _fix_dsn(dsn: str) -> str:
|
||||
"""URL-encode the password if it contains special characters."""
|
||||
from urllib.parse import quote
|
||||
if "://" not in dsn:
|
||||
return dsn # already key=value format
|
||||
scheme, rest = dsn.split("://", 1)
|
||||
at_idx = rest.rfind("@")
|
||||
if at_idx == -1:
|
||||
return dsn
|
||||
userinfo = rest[:at_idx]
|
||||
hostinfo = rest[at_idx + 1:]
|
||||
colon_idx = userinfo.find(":")
|
||||
if colon_idx == -1:
|
||||
return dsn
|
||||
user = userinfo[:colon_idx]
|
||||
password = userinfo[colon_idx + 1:]
|
||||
encoded = quote(password, safe="")
|
||||
return f"{scheme}://{user}:{encoded}@{hostinfo}"
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "SupabaseClient":
|
||||
"""Return the singleton instance, creating it if necessary.
|
||||
def get_instance(cls) -> SupabaseClient:
|
||||
"""Return the singleton instance, creating it if necessary."""
|
||||
if cls._instance is None:
|
||||
cfg = get_portfolio_config()
|
||||
dsn = cfg["supabase_connection_string"]
|
||||
if not dsn:
|
||||
raise PortfolioError(
|
||||
"SUPABASE_CONNECTION_STRING not configured. "
|
||||
"Set it in .env or as an environment variable."
|
||||
)
|
||||
cls._instance = cls(dsn)
|
||||
return cls._instance
|
||||
|
||||
Reads SUPABASE_URL and SUPABASE_KEY from ``get_portfolio_config()``.
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Close and reset the singleton (for testing)."""
|
||||
if cls._instance is not None:
|
||||
try:
|
||||
cls._instance._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
cls._instance = None
|
||||
|
||||
Raises:
|
||||
PortfolioError: If SUPABASE_URL or SUPABASE_KEY are not configured.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
def _cursor(self):
|
||||
"""Return a RealDictCursor."""
|
||||
return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Portfolio CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_portfolio(self, portfolio: Portfolio) -> Portfolio:
|
||||
"""Insert a new portfolio row.
|
||||
|
||||
Args:
|
||||
portfolio: Portfolio instance with all required fields set.
|
||||
|
||||
Returns:
|
||||
Portfolio with DB-assigned timestamps.
|
||||
|
||||
Raises:
|
||||
DuplicatePortfolioError: If portfolio_id already exists.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Insert a new portfolio row."""
|
||||
pid = portfolio.portfolio_id or str(uuid.uuid4())
|
||||
try:
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO portfolios
|
||||
(portfolio_id, name, cash, initial_cash, currency, report_path, metadata)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING *""",
|
||||
(pid, portfolio.name, portfolio.cash, portfolio.initial_cash,
|
||||
portfolio.currency, portfolio.report_path,
|
||||
json.dumps(portfolio.metadata)),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
except psycopg2.errors.UniqueViolation as exc:
|
||||
raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc
|
||||
return self._row_to_portfolio(row)
|
||||
|
||||
def get_portfolio(self, portfolio_id: str) -> Portfolio:
|
||||
"""Fetch a portfolio by ID.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
|
||||
Returns:
|
||||
Portfolio instance.
|
||||
|
||||
Raises:
|
||||
PortfolioNotFoundError: If no portfolio has that ID.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Fetch a portfolio by ID."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
|
||||
return self._row_to_portfolio(row)
|
||||
|
||||
def list_portfolios(self) -> list[Portfolio]:
|
||||
"""Return all portfolios ordered by created_at DESC."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
with self._cursor() as cur:
|
||||
cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC")
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_portfolio(r) for r in rows]
|
||||
|
||||
def update_portfolio(self, portfolio: Portfolio) -> Portfolio:
|
||||
"""Update mutable portfolio fields (cash, report_path, metadata).
|
||||
|
||||
Args:
|
||||
portfolio: Portfolio with updated field values.
|
||||
|
||||
Returns:
|
||||
Updated Portfolio with refreshed updated_at.
|
||||
|
||||
Raises:
|
||||
PortfolioNotFoundError: If portfolio_id does not exist.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Update mutable portfolio fields (cash, report_path, metadata)."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""UPDATE portfolios
|
||||
SET cash = %s, report_path = %s, metadata = %s
|
||||
WHERE portfolio_id = %s
|
||||
RETURNING *""",
|
||||
(portfolio.cash, portfolio.report_path,
|
||||
json.dumps(portfolio.metadata), portfolio.portfolio_id),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}")
|
||||
return self._row_to_portfolio(row)
|
||||
|
||||
def delete_portfolio(self, portfolio_id: str) -> None:
|
||||
"""Delete a portfolio and all associated data (CASCADE).
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the portfolio to delete.
|
||||
|
||||
Raises:
|
||||
PortfolioNotFoundError: If portfolio_id does not exist.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Delete a portfolio and all associated data (CASCADE)."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id",
|
||||
(portfolio_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Holdings CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def upsert_holding(self, holding: Holding) -> Holding:
|
||||
"""Insert or update a holding row (upsert on portfolio_id + ticker).
|
||||
|
||||
Args:
|
||||
holding: Holding instance with all required fields set.
|
||||
|
||||
Returns:
|
||||
Holding with DB-assigned / refreshed timestamps.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Insert or update a holding row (upsert on portfolio_id + ticker)."""
|
||||
hid = holding.holding_id or str(uuid.uuid4())
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO holdings
|
||||
(holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
|
||||
DO UPDATE SET shares = EXCLUDED.shares,
|
||||
avg_cost = EXCLUDED.avg_cost,
|
||||
sector = EXCLUDED.sector,
|
||||
industry = EXCLUDED.industry
|
||||
RETURNING *""",
|
||||
(hid, holding.portfolio_id, holding.ticker.upper(),
|
||||
holding.shares, holding.avg_cost, holding.sector, holding.industry),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_holding(row)
|
||||
|
||||
def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None:
|
||||
"""Return the holding for (portfolio_id, ticker), or None if not found.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
ticker: Ticker symbol (case-insensitive, stored as uppercase).
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Return the holding for (portfolio_id, ticker), or None."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s",
|
||||
(portfolio_id, ticker.upper()),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_holding(row) if row else None
|
||||
|
||||
def list_holdings(self, portfolio_id: str) -> list[Holding]:
|
||||
"""Return all holdings for a portfolio ordered by cost_basis DESC.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Return all holdings for a portfolio ordered by cost_basis DESC."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""SELECT * FROM holdings
|
||||
WHERE portfolio_id = %s
|
||||
ORDER BY shares * avg_cost DESC""",
|
||||
(portfolio_id,),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_holding(r) for r in rows]
|
||||
|
||||
def delete_holding(self, portfolio_id: str, ticker: str) -> None:
|
||||
"""Delete the holding for (portfolio_id, ticker).
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
ticker: Ticker symbol.
|
||||
|
||||
Raises:
|
||||
HoldingNotFoundError: If no such holding exists.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Delete the holding for (portfolio_id, ticker)."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id",
|
||||
(portfolio_id, ticker.upper()),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise HoldingNotFoundError(
|
||||
f"Holding not found: {ticker} in portfolio {portfolio_id}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trades
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def record_trade(self, trade: Trade) -> Trade:
|
||||
"""Insert a new trade record. Immutable — no update method.
|
||||
|
||||
Args:
|
||||
trade: Trade instance with all required fields set.
|
||||
|
||||
Returns:
|
||||
Trade with DB-assigned trade_id and trade_date.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Insert a new trade record."""
|
||||
tid = trade.trade_id or str(uuid.uuid4())
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO trades
|
||||
(trade_id, portfolio_id, ticker, action, shares, price,
|
||||
total_value, rationale, signal_source, metadata)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING *""",
|
||||
(tid, trade.portfolio_id, trade.ticker, trade.action,
|
||||
trade.shares, trade.price, trade.total_value,
|
||||
trade.rationale, trade.signal_source,
|
||||
json.dumps(trade.metadata)),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_trade(row)
|
||||
|
||||
def list_trades(
|
||||
self,
|
||||
|
|
@ -208,51 +250,142 @@ class SupabaseClient:
|
|||
ticker: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Trade]:
|
||||
"""Return recent trades for a portfolio, newest first.
|
||||
|
||||
Args:
|
||||
portfolio_id: Filter by portfolio.
|
||||
ticker: Optional additional filter by ticker symbol.
|
||||
limit: Maximum number of rows to return.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Return recent trades for a portfolio, newest first."""
|
||||
if ticker:
|
||||
query = """SELECT * FROM trades
|
||||
WHERE portfolio_id = %s AND ticker = %s
|
||||
ORDER BY trade_date DESC LIMIT %s"""
|
||||
params = (portfolio_id, ticker.upper(), limit)
|
||||
else:
|
||||
query = """SELECT * FROM trades
|
||||
WHERE portfolio_id = %s
|
||||
ORDER BY trade_date DESC LIMIT %s"""
|
||||
params = (portfolio_id, limit)
|
||||
with self._cursor() as cur:
|
||||
cur.execute(query, params)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_trade(r) for r in rows]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshots
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot:
|
||||
"""Insert a new immutable portfolio snapshot.
|
||||
|
||||
Args:
|
||||
snapshot: PortfolioSnapshot with all required fields set.
|
||||
|
||||
Returns:
|
||||
Snapshot with DB-assigned snapshot_id and snapshot_date.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Insert a new immutable portfolio snapshot."""
|
||||
sid = snapshot.snapshot_id or str(uuid.uuid4())
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""INSERT INTO snapshots
|
||||
(snapshot_id, portfolio_id, total_value, cash, equity_value,
|
||||
num_positions, holdings_snapshot, metadata)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||
RETURNING *""",
|
||||
(sid, snapshot.portfolio_id, snapshot.total_value,
|
||||
snapshot.cash, snapshot.equity_value, snapshot.num_positions,
|
||||
json.dumps(snapshot.holdings_snapshot),
|
||||
json.dumps(snapshot.metadata)),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_snapshot(row)
|
||||
|
||||
def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None:
|
||||
"""Return the most recent snapshot for a portfolio, or None.
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
"""Return the most recent snapshot for a portfolio, or None."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""SELECT * FROM snapshots
|
||||
WHERE portfolio_id = %s
|
||||
ORDER BY snapshot_date DESC LIMIT 1""",
|
||||
(portfolio_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
return self._row_to_snapshot(row) if row else None
|
||||
|
||||
def list_snapshots(
|
||||
self,
|
||||
portfolio_id: str,
|
||||
limit: int = 30,
|
||||
) -> list[PortfolioSnapshot]:
|
||||
"""Return snapshots newest-first up to limit.
|
||||
"""Return snapshots newest-first up to limit."""
|
||||
with self._cursor() as cur:
|
||||
cur.execute(
|
||||
"""SELECT * FROM snapshots
|
||||
WHERE portfolio_id = %s
|
||||
ORDER BY snapshot_date DESC LIMIT %s""",
|
||||
(portfolio_id, limit),
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_snapshot(r) for r in rows]
|
||||
|
||||
Args:
|
||||
portfolio_id: UUID of the target portfolio.
|
||||
limit: Maximum number of snapshots to return.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
# ------------------------------------------------------------------
|
||||
# Row -> Model helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _row_to_portfolio(row: dict) -> Portfolio:
|
||||
metadata = row.get("metadata") or {}
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
return Portfolio(
|
||||
portfolio_id=str(row["portfolio_id"]),
|
||||
name=row["name"],
|
||||
cash=float(row["cash"]),
|
||||
initial_cash=float(row["initial_cash"]),
|
||||
currency=row["currency"].strip(),
|
||||
created_at=str(row["created_at"]),
|
||||
updated_at=str(row["updated_at"]),
|
||||
report_path=row.get("report_path"),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_holding(row: dict) -> Holding:
|
||||
return Holding(
|
||||
holding_id=str(row["holding_id"]),
|
||||
portfolio_id=str(row["portfolio_id"]),
|
||||
ticker=row["ticker"],
|
||||
shares=float(row["shares"]),
|
||||
avg_cost=float(row["avg_cost"]),
|
||||
sector=row.get("sector"),
|
||||
industry=row.get("industry"),
|
||||
created_at=str(row["created_at"]),
|
||||
updated_at=str(row["updated_at"]),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_trade(row: dict) -> Trade:
|
||||
metadata = row.get("metadata") or {}
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
return Trade(
|
||||
trade_id=str(row["trade_id"]),
|
||||
portfolio_id=str(row["portfolio_id"]),
|
||||
ticker=row["ticker"],
|
||||
action=row["action"],
|
||||
shares=float(row["shares"]),
|
||||
price=float(row["price"]),
|
||||
total_value=float(row["total_value"]),
|
||||
trade_date=str(row["trade_date"]),
|
||||
rationale=row.get("rationale"),
|
||||
signal_source=row.get("signal_source"),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: dict) -> PortfolioSnapshot:
|
||||
holdings = row.get("holdings_snapshot") or []
|
||||
if isinstance(holdings, str):
|
||||
holdings = json.loads(holdings)
|
||||
metadata = row.get("metadata") or {}
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
return PortfolioSnapshot(
|
||||
snapshot_id=str(row["snapshot_id"]),
|
||||
portfolio_id=str(row["portfolio_id"]),
|
||||
snapshot_date=str(row["snapshot_date"]),
|
||||
total_value=float(row["total_value"]),
|
||||
cash=float(row["cash"]),
|
||||
equity_value=float(row["equity_value"]),
|
||||
num_positions=int(row["num_positions"]),
|
||||
holdings_snapshot=holdings,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue