From 29e75384cce0d509cbfde7d95b221f5c773a7080 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:28:36 +0000 Subject: [PATCH 1/5] Initial plan From f1cabe7a4af2865841399ba4896d741191f51abf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:40:48 +0000 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20portfolio=20manager=20data=20founda?= =?UTF-8?q?tion=20=E2=80=94=20docs,=20SQL=20migration,=20and=20module=20sc?= =?UTF-8?q?affolding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> --- docs/portfolio/00_overview.md | 245 +++++++++ docs/portfolio/01_phase1_plan.md | 352 +++++++++++++ docs/portfolio/02_data_models.md | 230 ++++++++ docs/portfolio/03_database_schema.md | 305 +++++++++++ docs/portfolio/04_repository_api.md | 492 ++++++++++++++++++ tests/portfolio/__init__.py | 1 + tests/portfolio/conftest.py | 110 ++++ tests/portfolio/test_models.py | 132 +++++ tests/portfolio/test_report_store.py | 145 ++++++ tests/portfolio/test_repository.py | 198 +++++++ tradingagents/portfolio/__init__.py | 55 ++ tradingagents/portfolio/config.py | 87 ++++ tradingagents/portfolio/exceptions.py | 60 +++ .../migrations/001_initial_schema.sql | 161 ++++++ tradingagents/portfolio/models.py | 207 ++++++++ tradingagents/portfolio/report_store.py | 245 +++++++++ tradingagents/portfolio/repository.py | 315 +++++++++++ tradingagents/portfolio/supabase_client.py | 252 +++++++++ 18 files changed, 3592 insertions(+) create mode 100644 docs/portfolio/00_overview.md create mode 100644 docs/portfolio/01_phase1_plan.md create mode 100644 docs/portfolio/02_data_models.md create mode 100644 docs/portfolio/03_database_schema.md create mode 100644 docs/portfolio/04_repository_api.md create mode 100644 tests/portfolio/__init__.py create mode 100644 tests/portfolio/conftest.py create mode 100644 tests/portfolio/test_models.py create mode 100644 tests/portfolio/test_report_store.py create mode 100644 tests/portfolio/test_repository.py create mode 100644 tradingagents/portfolio/__init__.py create mode 100644 tradingagents/portfolio/config.py create mode 100644 tradingagents/portfolio/exceptions.py create mode 100644 tradingagents/portfolio/migrations/001_initial_schema.sql create mode 100644 tradingagents/portfolio/models.py create mode 100644 tradingagents/portfolio/report_store.py create mode 100644 tradingagents/portfolio/repository.py create mode 100644 tradingagents/portfolio/supabase_client.py diff --git a/docs/portfolio/00_overview.md b/docs/portfolio/00_overview.md new file mode 100644 index 00000000..a6f7ab46 --- /dev/null +++ b/docs/portfolio/00_overview.md @@ -0,0 +1,245 @@ +# Portfolio Manager Agent — Design Overview + + + +## Feature Description + +The Portfolio Manager Agent (PMA) is an autonomous agent that manages a simulated +investment portfolio end-to-end. It performs the following actions in sequence: + +1. **Initiates market research** — triggers the existing `ScannerGraph` to produce a + macro watchlist of top candidate tickers. +2. **Initiates per-ticker analysis** — feeds scan results into the existing + `MacroBridge` / `TradingAgentsGraph` pipeline for high-conviction candidates. +3. **Loads current holdings** — queries the Supabase database for the active portfolio + state (positions, cash balance, sector weights). +4. **Requests lightweight holding reviews** — for each existing holding, runs a + quick `HoldingReviewerAgent` (quick_think) that checks price action and recent + news — no full bull/bear debate needed. +5. **Computes portfolio-level risk metrics** — pure Python, no LLM: + Sharpe ratio, Sortino ratio, beta, 95 % VaR, max drawdown, sector concentration, + correlation matrix, and what-if buy/sell scenarios. +6. **Makes allocation decisions** — the Portfolio Manager Agent (deep_think + + memory) reads all inputs and outputs a structured JSON with sells, buys, holds, + target cash %, and detailed rationale. +7. **Executes mock trades** — validates decisions against constraints, records trades + in Supabase, updates holdings, and takes an immutable snapshot. + +--- + +## Architecture Decision: Supabase (PostgreSQL) + Filesystem + +``` +┌─────────────────────────────────────────────────┐ +│ Supabase (PostgreSQL) │ +│ │ +│ portfolios holdings trades snapshots │ +│ │ +│ "What do I own right now?" │ +│ "What trades did I make?" │ +│ "What was my portfolio value on date X?" │ +└────────────────────┬────────────────────────────┘ + │ report_path column + ▼ +┌─────────────────────────────────────────────────┐ +│ Filesystem (reports/) │ +│ │ +│ reports/daily/{date}/ │ +│ market/ ← scan output │ +│ {TICKER}/ ← per-ticker analysis │ +│ portfolio/ │ +│ holdings_review.json │ +│ risk_metrics.json │ +│ pm_decision.json │ +│ pm_decision.md (human-readable) │ +│ │ +│ "Why did I decide this?" │ +│ "What was the macro context?" │ +│ "What did the risk model say?" │ +└─────────────────────────────────────────────────┘ +``` + +**Rationale:** + +| Concern | Storage | Why | +|---------|---------|-----| +| Transactional integrity (trades) | Supabase | ACID, foreign keys, row-level security | +| Fast portfolio queries (weights, cash) | Supabase | SQL aggregations | +| LLM reports (large text, markdown) | Filesystem | Avoids bloating the DB | +| Agent memory / rationale | Filesystem | Easy to inspect and version | +| Audit trail of decisions | Filesystem | Markdown readable by humans | + +The `report_path` column in the `portfolios` table points to the daily portfolio +subdirectory on disk: `reports/daily/{date}/portfolio/`. + +--- + +## 5-Phase Workflow + +``` +┌────────────────────────────────────────────────────────────────────────────┐ +│ PHASE 1 (parallel) │ +│ │ +│ 1a. ScannerGraph.scan(date) 1b. Load Holdings + Fetch Prices │ +│ → macro_scan_summary.json → List[Holding] with │ +│ watchlist of top candidates current_price, current_value │ +└───────────────────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ PHASE 2 (parallel) │ +│ │ +│ 2a. New Candidate Analysis 2b. Holding Re-evaluation │ +│ MacroBridge.run_all_tickers() HoldingReviewerAgent (quick_think)│ +│ Full bull/bear pipeline per 7-day price + 3-day news │ +│ HIGH/MEDIUM conviction → JSON: signal/confidence/reason │ +│ candidates that are NOT urgency per holding │ +│ already held │ +└───────────────────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ PHASE 3 (Python, no LLM) │ +│ │ +│ Risk Metrics Computation │ +│ • Sharpe ratio (annualised, rf = 0) │ +│ • Sortino ratio (downside deviation) │ +│ • Portfolio beta (vs SPY) │ +│ • 95 % VaR (historical simulation, 30-day window) │ +│ • Max drawdown (peak-to-trough, 90-day window) │ +│ • Sector concentration (weight per GICS sector) │ +│ • Correlation matrix (all holdings) │ +│ • What-if scenarios (buy X, sell Y → new weights) │ +└───────────────────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ PHASE 4 Portfolio Manager Agent (deep_think + memory) │ +│ │ +│ Reads: macro context, holdings, candidate signals, re-eval signals, │ +│ risk metrics, budget constraint, past decisions (memory) │ +│ │ +│ Outputs structured JSON: │ +│ { │ +│ "sells": [{"ticker": "X", "shares": 10, "reason": "..."}], │ +│ "buys": [{"ticker": "Y", "shares": 5, "reason": "..."}], │ +│ "holds": ["Z"], │ +│ "target_cash_pct": 0.08, │ +│ "rationale": "...", │ +│ "risk_summary": "..." │ +│ } │ +└───────────────────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ PHASE 5 Trade Execution (Mock) │ +│ │ +│ • Validate decisions against constraints (position size, sector, cash) │ +│ • Record each trade in Supabase (trades table) │ +│ • Update holdings (avg cost basis, shares) │ +│ • Deduct / credit cash balance │ +│ • Take immutable portfolio snapshot │ +│ • Save PM decision + risk report to filesystem │ +└────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Agent Specifications + +### Portfolio Manager Agent (PMA) + +| Property | Value | +|----------|-------| +| LLM tier | `deep_think` | +| Memory | Enabled — reads previous PM decision files from filesystem | +| Output format | Structured JSON (validated before trade execution) | +| Invocation | Once per run, after Phases 1–3 | + +**Prompt inputs:** +- Macro scan summary (top candidates + context) +- Current holdings list (ticker, shares, avg cost, current price, weight, sector) +- Candidate analysis signals (BUY/SELL/HOLD per ticker from Phase 2a) +- Holding review signals (signal, confidence, reason, urgency per holding from Phase 2b) +- Risk metrics report (Phase 3 output) +- Budget constraint (available cash) +- Portfolio constraints (see below) +- Previous decision (last PM decision file for memory continuity) + +### Holding Reviewer Agent + +| Property | Value | +|----------|-------| +| LLM tier | `quick_think` | +| Memory | Disabled | +| Output format | Structured JSON | +| Tools | `get_stock_data` (7-day window), `get_news` (3-day window), RSI, MACD | +| Invocation | Once per existing holding (parallelisable) | + +**Output schema per holding:** +```json +{ + "ticker": "AAPL", + "signal": "HOLD", + "confidence": 0.72, + "reason": "Price action neutral; no material news. RSI 52, MACD flat.", + "urgency": "LOW" +} +``` + +--- + +## PM Agent Constraints + +These constraints are **hard limits** enforced during Phase 5 (trade execution). +The PM Agent is also instructed to respect them in its prompt. + +| Constraint | Value | +|------------|-------| +| Max position size | 15 % of portfolio value | +| Max sector exposure | 35 % of portfolio value | +| Min cash reserve | 5 % of portfolio value | +| Max number of positions | 15 | + +--- + +## PM Risk Management Rules + +These rules trigger specific actions and are part of the PM Agent's system prompt: + +| Trigger | Action | +|---------|--------| +| Portfolio beta > 1.3 | Reduce cyclical / high-beta positions | +| Sector exposure > 35 % | Diversify — sell smallest position in that sector | +| Sharpe ratio < 0.5 | Raise cash — reduce overall exposure | +| Max drawdown > 15 % | Go defensive — reduce equity allocation | +| Daily 95 % VaR > 3 % | Reduce position sizes to lower tail risk | + +--- + +## 10-Phase Implementation Roadmap + +| Phase | Deliverable | Effort | +|-------|-------------|--------| +| 1 | Data foundation (this PR) — models, DB, filesystem, repository | ~2–3 days | +| 2 | Holding Reviewer Agent | ~1 day | +| 3 | Risk metrics engine (Phase 3 of workflow) | ~1–2 days | +| 4 | Portfolio Manager Agent (LLM, structured output) | ~2 days | +| 5 | Trade execution engine (Phase 5 of workflow) | ~1 day | +| 6 | Full orchestration graph (LangGraph) tying all phases | ~2 days | +| 7 | CLI command `pm run` | ~0.5 days | +| 8 | End-to-end integration tests | ~1 day | +| 9 | Performance tuning + concurrency (Phase 2 parallelism) | ~1 day | +| 10 | Documentation, memory system update, PR review | ~0.5 days | + +**Total estimate: ~15–22 days** + +--- + +## References + +- `tradingagents/pipeline/macro_bridge.py` — existing scan → per-ticker bridge +- `tradingagents/report_paths.py` — filesystem path conventions +- `tradingagents/default_config.py` — config pattern to follow +- `tradingagents/agents/scanners/` — scanner agent examples +- `tradingagents/graph/scanner_setup.py` — parallel graph node patterns diff --git a/docs/portfolio/01_phase1_plan.md b/docs/portfolio/01_phase1_plan.md new file mode 100644 index 00000000..dd14c2ee --- /dev/null +++ b/docs/portfolio/01_phase1_plan.md @@ -0,0 +1,352 @@ +# Phase 1 Implementation Plan — Data Foundation + + + +## Goal + +Build the data foundation layer for the Portfolio Manager feature. + +After Phase 1 you should be able to: +- Create and retrieve portfolios +- Manage holdings (add, update, remove) with correct avg-cost-basis accounting +- Record mock trades +- Take immutable portfolio snapshots +- Save and load all report types (scans, analysis, holding reviews, risk, PM decisions) +- Pass a 90 %+ test coverage gate on all new modules + +--- + +## File Structure + +``` +tradingagents/portfolio/ +├── __init__.py ← public exports +├── models.py ← Portfolio, Holding, Trade, PortfolioSnapshot dataclasses +├── config.py ← PORTFOLIO_CONFIG dict + helpers +├── exceptions.py ← domain exception hierarchy +├── supabase_client.py ← Supabase CRUD wrapper +├── report_store.py ← Filesystem document storage +├── repository.py ← Unified data-access façade (Supabase + filesystem) +└── migrations/ + └── 001_initial_schema.sql + +tests/portfolio/ +├── __init__.py +├── conftest.py ← shared fixtures +├── test_models.py +├── test_report_store.py +└── test_repository.py +``` + +--- + +## Task 1 — Data Models (`models.py`) + +**Estimated effort:** 2–3 h + +### Deliverables + +Four dataclasses fully type-annotated: + +- `Portfolio` +- `Holding` +- `Trade` +- `PortfolioSnapshot` + +Each class must implement: +- `to_dict() -> dict` — serialise for DB / JSON +- `from_dict(data: dict) -> Self` — deserialise from DB / JSON +- `enrich(**kwargs)` — attach runtime-computed fields (prices, weights, P&L) + +### Field Specifications + +See `docs/portfolio/02_data_models.md` for full field tables. + +### Acceptance Criteria + +- All fields have explicit Python type annotations +- `to_dict()` → `from_dict()` round-trip is lossless for all fields +- `enrich()` correctly computes `current_value`, `unrealized_pnl`, `unrealized_pnl_pct`, `weight` +- 100 % line coverage in `test_models.py` + +--- + +## Task 2 — Portfolio Config (`config.py`) + +**Estimated effort:** 1 h + +### Deliverables + +```python +PORTFOLIO_CONFIG: dict # all tunable parameters +get_portfolio_config() -> dict # returns merged config (defaults + env overrides) +validate_config(cfg: dict) # raises ValueError on invalid values +``` + +### Environment Variables + +All are optional and default to the values shown: + +| Env Var | Default | Description | +|---------|---------|-------------| +| `SUPABASE_URL` | `""` | Supabase project URL | +| `SUPABASE_KEY` | `""` | Supabase anon/service key | +| `PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports | +| `PM_MAX_POSITIONS` | `15` | Max number of open positions | +| `PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight | +| `PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight | +| `PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve | +| `PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) | + +### Acceptance Criteria + +- All env vars read with `os.getenv`, defaulting gracefully when unset +- `validate_config` raises `ValueError` for `max_position_pct > 1.0`, + `min_cash_pct < 0`, `max_positions < 1`, etc. + +--- + +## Task 3 — Supabase Migration (`migrations/001_initial_schema.sql`) + +**Estimated effort:** 1–2 h + +### Deliverables + +A single idempotent SQL file (`CREATE TABLE IF NOT EXISTS`) that creates: + +- `portfolios` table +- `holdings` table +- `trades` table +- `snapshots` table +- All CHECK constraints +- All FOREIGN KEY constraints +- All indexes (PK + query-path indexes) +- `updated_at` trigger function + triggers on portfolios, holdings + +See `docs/portfolio/03_database_schema.md` for full schema. + +### Acceptance Criteria + +- File runs without error on a fresh Supabase PostgreSQL database +- All tables created with correct column types and constraints +- `updated_at` auto-updates on every row modification + +--- + +## Task 4 — Supabase Client (`supabase_client.py`) + +**Estimated effort:** 3–4 h + +### Deliverables + +`SupabaseClient` class (singleton pattern) with: + +**Portfolio CRUD** +- `create_portfolio(portfolio: Portfolio) -> Portfolio` +- `get_portfolio(portfolio_id: str) -> Portfolio` +- `list_portfolios() -> list[Portfolio]` +- `update_portfolio(portfolio: Portfolio) -> Portfolio` +- `delete_portfolio(portfolio_id: str) -> None` + +**Holdings CRUD** +- `upsert_holding(holding: Holding) -> Holding` +- `get_holding(portfolio_id: str, ticker: str) -> Holding | None` +- `list_holdings(portfolio_id: str) -> list[Holding]` +- `delete_holding(portfolio_id: str, ticker: str) -> None` + +**Trades** +- `record_trade(trade: Trade) -> Trade` +- `list_trades(portfolio_id: str, limit: int = 100) -> list[Trade]` + +**Snapshots** +- `save_snapshot(snapshot: PortfolioSnapshot) -> PortfolioSnapshot` +- `get_latest_snapshot(portfolio_id: str) -> PortfolioSnapshot | None` +- `list_snapshots(portfolio_id: str, limit: int = 30) -> list[PortfolioSnapshot]` + +### Error Handling + +All methods translate Supabase/HTTP errors into domain exceptions (see Task below). +Methods that query a single row raise `PortfolioNotFoundError` when no row is found. + +### Acceptance Criteria + +- Singleton — only one Supabase connection per process +- All public methods fully type-annotated +- Supabase integration tests auto-skip when `SUPABASE_URL` is unset + +--- + +## Task 5 — Report Store (`report_store.py`) + +**Estimated effort:** 3–4 h + +### Deliverables + +`ReportStore` class with typed save/load methods for each report type: + +| Method | Description | +|--------|-------------| +| `save_scan(date, data)` | Save macro scan JSON | +| `load_scan(date)` | Load macro scan JSON | +| `save_analysis(date, ticker, data)` | Save per-ticker analysis report | +| `load_analysis(date, ticker)` | Load per-ticker analysis report | +| `save_holding_review(date, ticker, data)` | Save holding reviewer output | +| `load_holding_review(date, ticker)` | Load holding reviewer output | +| `save_risk_metrics(date, portfolio_id, data)` | Save risk computation output | +| `load_risk_metrics(date, portfolio_id)` | Load risk computation output | +| `save_pm_decision(date, portfolio_id, data)` | Save PM agent decision JSON + MD | +| `load_pm_decision(date, portfolio_id)` | Load PM agent decision JSON | +| `list_pm_decisions(portfolio_id)` | List all saved PM decision paths | + +### Directory Convention + +``` +reports/daily/{date}/ +├── market/ +│ └── macro_scan_summary.json ← save_scan / load_scan +├── {TICKER}/ +│ └── complete_report.md ← save_analysis / load_analysis (existing) +└── portfolio/ + ├── {TICKER}_holding_review.json ← save_holding_review / load_holding_review + ├── {portfolio_id}_risk_metrics.json + ├── {portfolio_id}_pm_decision.json + └── {portfolio_id}_pm_decision.md (human-readable version) +``` + +### Acceptance Criteria + +- Directories created automatically on first write +- `load_*` returns `None` when the file doesn't exist (no exception) +- JSON serialisation uses `json.dumps(indent=2)` + +--- + +## Task 6 — Repository (`repository.py`) + +**Estimated effort:** 4–5 h + +### Deliverables + +`PortfolioRepository` class — unified façade over `SupabaseClient` + `ReportStore`. + +**Key business logic:** + +``` +add_holding(portfolio_id, ticker, shares, price): + existing = client.get_holding(portfolio_id, ticker) + if existing: + new_avg_cost = (existing.avg_cost * existing.shares + price * shares) + / (existing.shares + shares) + holding.shares += shares + holding.avg_cost = new_avg_cost + else: + holding = Holding(ticker=ticker, shares=shares, avg_cost=price, ...) + portfolio.cash -= shares * price # deduct cash + client.upsert_holding(holding) + client.update_portfolio(portfolio) # persist cash change + +remove_holding(portfolio_id, ticker, shares, price): + existing = client.get_holding(portfolio_id, ticker) + if existing.shares < shares: + raise InsufficientSharesError(...) + if shares == existing.shares: + client.delete_holding(portfolio_id, ticker) + else: + existing.shares -= shares + client.upsert_holding(existing) + portfolio.cash += shares * price # credit proceeds + client.update_portfolio(portfolio) +``` + +All DB operations execute as a logical unit (best-effort; full Supabase transactions +require PG functions — deferred to Phase 3+). + +### Acceptance Criteria + +- `add_holding` correctly updates avg cost basis on repeated buys +- `remove_holding` raises `InsufficientSharesError` when shares would go negative +- `add_holding` raises `InsufficientCashError` when cash < `shares * price` +- Repository integration tests auto-skip when `SUPABASE_URL` is unset + +--- + +## Task 7 — Package Setup + +**Estimated effort:** 1 h + +### Deliverables + +1. `tradingagents/portfolio/__init__.py` — export public symbols +2. `pyproject.toml` — add `supabase>=2.0.0` to dependencies +3. `.env.example` — add new env vars (`SUPABASE_URL`, `SUPABASE_KEY`, `PM_*`) +4. `tradingagents/default_config.py` — merge `PORTFOLIO_CONFIG` into `DEFAULT_CONFIG` + under a `"portfolio"` key (non-breaking addition) + +### Acceptance Criteria + +- `from tradingagents.portfolio import PortfolioRepository` works after install +- `pip install -e ".[dev]"` succeeds with the new dependency + +--- + +## Task 8 — Tests + +**Estimated effort:** 3–4 h + +### Test List + +**`test_models.py`** +- `test_portfolio_to_dict_round_trip` +- `test_holding_to_dict_round_trip` +- `test_trade_to_dict_round_trip` +- `test_snapshot_to_dict_round_trip` +- `test_holding_enrich_computes_current_value` +- `test_holding_enrich_computes_unrealized_pnl` +- `test_holding_enrich_computes_weight` +- `test_holding_enrich_handles_zero_cost` + +**`test_report_store.py`** +- `test_save_and_load_scan` +- `test_save_and_load_analysis` +- `test_save_and_load_holding_review` +- `test_save_and_load_risk_metrics` +- `test_save_and_load_pm_decision_json` +- `test_load_returns_none_for_missing_file` +- `test_list_pm_decisions` +- `test_directories_created_on_write` + +**`test_repository.py`** (Supabase tests skip when `SUPABASE_URL` unset) +- `test_add_holding_new_position` +- `test_add_holding_updates_avg_cost` +- `test_remove_holding_full_position` +- `test_remove_holding_partial_position` +- `test_remove_holding_raises_insufficient_shares` +- `test_add_holding_raises_insufficient_cash` +- `test_record_and_list_trades` +- `test_save_and_load_snapshot` + +### Coverage Target + +90 %+ for `models.py` and `report_store.py`. +Integration tests (`test_repository.py`) auto-skip when Supabase is unavailable. + +--- + +## Execution Order + +``` +Day 1 (parallel tracks) + Track A: Task 1 (models) → Task 3 (SQL migration) + Track B: Task 2 (config) → Task 7 (package setup partial) + +Day 2 (parallel tracks) + Track A: Task 4 (SupabaseClient) + Track B: Task 5 (ReportStore) + +Day 3 + Task 6 (Repository) + Task 8 (Tests) + Task 7 (package setup final — pyproject.toml, .env.example) +``` + +**Total estimate: ~18–24 hours** diff --git a/docs/portfolio/02_data_models.md b/docs/portfolio/02_data_models.md new file mode 100644 index 00000000..b4ac86f5 --- /dev/null +++ b/docs/portfolio/02_data_models.md @@ -0,0 +1,230 @@ +# Data Models — Full Specification + + + +All models live in `tradingagents/portfolio/models.py` as Python `dataclass` types. +They must be fully type-annotated and support lossless `to_dict` / `from_dict` +round-trips. + +--- + +## `Portfolio` + +Represents a single managed portfolio (one user may eventually have multiple). + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `portfolio_id` | `str` | Yes | UUID, primary key | +| `name` | `str` | Yes | Human-readable name, e.g. "Main Portfolio" | +| `cash` | `float` | Yes | Available cash balance in USD | +| `initial_cash` | `float` | Yes | Starting capital (immutable after creation) | +| `currency` | `str` | Yes | ISO 4217 code, default `"USD"` | +| `created_at` | `str` | Yes | ISO-8601 UTC datetime string | +| `updated_at` | `str` | Yes | ISO-8601 UTC datetime string | +| `report_path` | `str \| None` | No | Filesystem path to today's portfolio report dir | +| `metadata` | `dict` | No | Free-form JSON for agent notes / tags | + +### Computed / Derived Fields (not stored in DB) + +| Field | Type | Description | +|-------|------|-------------| +| `total_value` | `float` | `cash` + sum of all holding `current_value` | +| `equity_value` | `float` | sum of all holding `current_value` | +| `cash_pct` | `float` | `cash / total_value` | + +### Methods + +```python +def to_dict(self) -> dict: + """Serialise all stored fields to a flat dict suitable for JSON / Supabase insert.""" + +def from_dict(cls, data: dict) -> "Portfolio": + """Deserialise from a DB row or JSON dict. Missing optional fields default gracefully.""" + +def enrich(self, holdings: list["Holding"]) -> "Portfolio": + """Compute total_value, equity_value, cash_pct from the provided holdings list.""" +``` + +--- + +## `Holding` + +Represents a single open position within a portfolio. + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `holding_id` | `str` | Yes | UUID, primary key | +| `portfolio_id` | `str` | Yes | FK → portfolios.portfolio_id | +| `ticker` | `str` | Yes | Stock ticker symbol, e.g. `"AAPL"` | +| `shares` | `float` | Yes | Number of shares held | +| `avg_cost` | `float` | Yes | Average cost basis per share (USD) | +| `sector` | `str \| None` | No | GICS sector name | +| `industry` | `str \| None` | No | GICS industry name | +| `created_at` | `str` | Yes | ISO-8601 UTC datetime string | +| `updated_at` | `str` | Yes | ISO-8601 UTC datetime string | + +### Runtime-Computed Fields (not stored in DB) + +These are populated by `enrich()` and available for agent/analysis use: + +| Field | Type | Description | +|-------|------|-------------| +| `current_price` | `float \| None` | Latest market price per share | +| `current_value` | `float \| None` | `current_price * shares` | +| `cost_basis` | `float` | `avg_cost * shares` | +| `unrealized_pnl` | `float \| None` | `current_value - cost_basis` | +| `unrealized_pnl_pct` | `float \| None` | `unrealized_pnl / cost_basis` (0 if cost_basis == 0) | +| `weight` | `float \| None` | `current_value / portfolio_total_value` | + +### Methods + +```python +def to_dict(self) -> dict: + """Serialise stored fields only (not runtime-computed fields).""" + +def from_dict(cls, data: dict) -> "Holding": + """Deserialise from DB row or JSON dict.""" + +def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding": + """ + Populate runtime-computed fields in-place and return self. + + Args: + current_price: Latest market price for this ticker. + portfolio_total_value: Total portfolio value (cash + equity) for weight calc. + """ +``` + +--- + +## `Trade` + +Immutable record of a single mock trade execution. Never modified after creation. + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `trade_id` | `str` | Yes | UUID, primary key | +| `portfolio_id` | `str` | Yes | FK → portfolios.portfolio_id | +| `ticker` | `str` | Yes | Stock ticker symbol | +| `action` | `str` | Yes | `"BUY"` or `"SELL"` | +| `shares` | `float` | Yes | Number of shares traded | +| `price` | `float` | Yes | Execution price per share (USD) | +| `total_value` | `float` | Yes | `shares * price` | +| `trade_date` | `str` | Yes | ISO-8601 UTC datetime of execution | +| `rationale` | `str \| None` | No | PM agent rationale for this trade | +| `signal_source` | `str \| None` | No | `"scanner"`, `"holding_review"`, `"pm_agent"` | +| `metadata` | `dict` | No | Free-form JSON | + +### Methods + +```python +def to_dict(self) -> dict: + """Serialise all fields.""" + +def from_dict(cls, data: dict) -> "Trade": + """Deserialise from DB row or JSON dict.""" +``` + +--- + +## `PortfolioSnapshot` + +Point-in-time immutable record of the portfolio state. Taken after every trade +execution session (Phase 5 of the workflow). Used for performance tracking. + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `snapshot_id` | `str` | Yes | UUID, primary key | +| `portfolio_id` | `str` | Yes | FK → portfolios.portfolio_id | +| `snapshot_date` | `str` | Yes | ISO-8601 UTC datetime | +| `total_value` | `float` | Yes | Cash + equity at snapshot time | +| `cash` | `float` | Yes | Cash balance at snapshot time | +| `equity_value` | `float` | Yes | Sum of position values at snapshot time | +| `num_positions` | `int` | Yes | Number of open positions | +| `holdings_snapshot` | `list[dict]` | Yes | Serialised list of holding dicts (as-of) | +| `metadata` | `dict` | No | Free-form JSON (e.g. PM decision path) | + +### Methods + +```python +def to_dict(self) -> dict: + """Serialise all fields. `holdings_snapshot` is already a list[dict].""" + +def from_dict(cls, data: dict) -> "PortfolioSnapshot": + """Deserialise. `holdings_snapshot` parsed from JSON string if needed.""" +``` + +--- + +## Serialisation Contract + +### `to_dict()` + +- Returns a flat `dict[str, Any]` +- All values must be JSON-serialisable (str, int, float, bool, list, dict, None) +- `datetime` objects → ISO-8601 string (`isoformat()`) +- `Decimal` values → `float` +- Runtime-computed fields (`current_price`, `weight`, etc.) are **excluded** +- Complex nested fields (`metadata`, `holdings_snapshot`) are included as-is + +### `from_dict()` + +- Class method; must be callable as `Portfolio.from_dict(row)` +- Handles missing optional fields with `data.get("field", default)` +- Does **not** raise on extra keys in `data` +- Does **not** populate runtime-computed fields (call `enrich()` separately) + +--- + +## Enrichment Logic + +### `Holding.enrich(current_price, portfolio_total_value)` + +```python +self.current_price = current_price +self.current_value = current_price * self.shares +self.cost_basis = self.avg_cost * self.shares +self.unrealized_pnl = self.current_value - self.cost_basis +if self.cost_basis > 0: + self.unrealized_pnl_pct = self.unrealized_pnl / self.cost_basis +else: + self.unrealized_pnl_pct = 0.0 +if portfolio_total_value > 0: + self.weight = self.current_value / portfolio_total_value +else: + self.weight = 0.0 +return self +``` + +### `Portfolio.enrich(holdings)` + +```python +self.equity_value = sum(h.current_value or 0 for h in holdings) +self.total_value = self.cash + self.equity_value +if self.total_value > 0: + self.cash_pct = self.cash / self.total_value +else: + self.cash_pct = 1.0 +return self +``` + +--- + +## Type Alias Reference + +```python +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any +``` + +All `metadata` fields use `dict[str, Any]` with `field(default_factory=dict)`. +All optional fields default to `None` unless noted otherwise. diff --git a/docs/portfolio/03_database_schema.md b/docs/portfolio/03_database_schema.md new file mode 100644 index 00000000..757f4549 --- /dev/null +++ b/docs/portfolio/03_database_schema.md @@ -0,0 +1,305 @@ +# Database & Filesystem Schema + + + +## Supabase (PostgreSQL) Schema + +All tables are created in the `public` schema (Supabase default). + +--- + +### `portfolios` + +Stores one row per managed portfolio. + +```sql +CREATE TABLE IF NOT EXISTS portfolios ( + portfolio_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + cash NUMERIC(18,4) NOT NULL CHECK (cash >= 0), + initial_cash NUMERIC(18,4) NOT NULL CHECK (initial_cash > 0), + currency CHAR(3) NOT NULL DEFAULT 'USD', + report_path TEXT, + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +``` + +**Constraints:** +- `cash >= 0` — portfolio can be fully invested but never negative +- `initial_cash > 0` — must start with positive capital +- `currency` is 3-char ISO 4217 code + +--- + +### `holdings` + +Stores one row per open position per portfolio. Row is deleted when shares reach 0. + +```sql +CREATE TABLE IF NOT EXISTS holdings ( + holding_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + ticker TEXT NOT NULL, + shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), + avg_cost NUMERIC(18,4) NOT NULL CHECK (avg_cost >= 0), + sector TEXT, + industry TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + CONSTRAINT holdings_portfolio_ticker_unique UNIQUE (portfolio_id, ticker) +); +``` + +**Constraints:** +- `shares > 0` — zero-share positions are deleted, not stored +- `avg_cost >= 0` — cost basis is non-negative +- `UNIQUE (portfolio_id, ticker)` — one row per ticker per portfolio (upsert pattern) + +--- + +### `trades` + +Immutable append-only log of every mock trade execution. + +```sql +CREATE TABLE IF NOT EXISTS trades ( + trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + ticker TEXT NOT NULL, + action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')), + shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), + price NUMERIC(18,4) NOT NULL CHECK (price > 0), + total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0), + trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), + rationale TEXT, + signal_source TEXT, + metadata JSONB NOT NULL DEFAULT '{}', + + CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL')) +); +``` + +**Constraints:** +- `action IN ('BUY', 'SELL')` — only two valid actions +- `shares > 0`, `price > 0` — all quantities positive +- No `updated_at` — trades are immutable + +--- + +### `snapshots` + +Point-in-time portfolio state snapshots taken after each trade session. + +```sql +CREATE TABLE IF NOT EXISTS snapshots ( + snapshot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + snapshot_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), + total_value NUMERIC(18,4) NOT NULL, + cash NUMERIC(18,4) NOT NULL, + equity_value NUMERIC(18,4) NOT NULL, + num_positions INTEGER NOT NULL CHECK (num_positions >= 0), + holdings_snapshot JSONB NOT NULL DEFAULT '[]', + metadata JSONB NOT NULL DEFAULT '{}' +); +``` + +**Constraints:** +- `num_positions >= 0` — can have 0 positions (fully in cash) +- `holdings_snapshot` is a JSONB array of serialised `Holding.to_dict()` objects +- No `updated_at` — snapshots are immutable + +--- + +## Indexes + +```sql +-- portfolios: fast lookup by name +CREATE INDEX IF NOT EXISTS idx_portfolios_name + ON portfolios (name); + +-- holdings: list all holdings for a portfolio (most common query) +CREATE INDEX IF NOT EXISTS idx_holdings_portfolio_id + ON holdings (portfolio_id); + +-- holdings: fast ticker lookup within a portfolio +CREATE INDEX IF NOT EXISTS idx_holdings_portfolio_ticker + ON holdings (portfolio_id, ticker); + +-- trades: list recent trades for a portfolio, newest first +CREATE INDEX IF NOT EXISTS idx_trades_portfolio_id_date + ON trades (portfolio_id, trade_date DESC); + +-- trades: filter by ticker within a portfolio +CREATE INDEX IF NOT EXISTS idx_trades_portfolio_ticker + ON trades (portfolio_id, ticker); + +-- snapshots: get latest snapshot for a portfolio +CREATE INDEX IF NOT EXISTS idx_snapshots_portfolio_id_date + ON snapshots (portfolio_id, snapshot_date DESC); +``` + +--- + +## `updated_at` Trigger + +Automatically updates `updated_at` on every row modification for `portfolios` +and `holdings` (trades and snapshots are immutable). + +```sql +-- Trigger function (shared across tables) +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Apply to portfolios +CREATE OR REPLACE TRIGGER trg_portfolios_updated_at + BEFORE UPDATE ON portfolios + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); + +-- Apply to holdings +CREATE OR REPLACE TRIGGER trg_holdings_updated_at + BEFORE UPDATE ON holdings + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); +``` + +--- + +## Example Queries + +### Get active portfolio with cash balance + +```sql +SELECT portfolio_id, name, cash, initial_cash, currency +FROM portfolios +WHERE portfolio_id = $1; +``` + +### Get all holdings with sector summary + +```sql +SELECT + ticker, + shares, + avg_cost, + shares * avg_cost AS cost_basis, + sector +FROM holdings +WHERE portfolio_id = $1 +ORDER BY shares * avg_cost DESC; +``` + +### Sector concentration (cost-basis weighted) + +```sql +SELECT + COALESCE(sector, 'Unknown') AS sector, + SUM(shares * avg_cost) AS sector_cost_basis +FROM holdings +WHERE portfolio_id = $1 +GROUP BY sector +ORDER BY sector_cost_basis DESC; +``` + +### Recent 20 trades for a portfolio + +```sql +SELECT ticker, action, shares, price, total_value, trade_date, rationale +FROM trades +WHERE portfolio_id = $1 +ORDER BY trade_date DESC +LIMIT 20; +``` + +### Latest portfolio snapshot + +```sql +SELECT * +FROM snapshots +WHERE portfolio_id = $1 +ORDER BY snapshot_date DESC +LIMIT 1; +``` + +### Portfolio performance over time (snapshot series) + +```sql +SELECT snapshot_date, total_value, cash, equity_value, num_positions +FROM snapshots +WHERE portfolio_id = $1 +ORDER BY snapshot_date ASC; +``` + +--- + +## Filesystem Directory Structure + +Reports and documents are stored under the project's `reports/` directory using +the existing convention from `tradingagents/report_paths.py`. + +``` +reports/ +└── daily/ + └── {YYYY-MM-DD}/ + ├── market/ + │ ├── geopolitical_report.md + │ ├── market_movers_report.md + │ ├── sector_report.md + │ ├── industry_deep_dive_report.md + │ ├── macro_synthesis_report.md + │ └── macro_scan_summary.json ← ReportStore.save_scan / load_scan + │ + ├── {TICKER}/ ← one dir per analysed ticker + │ ├── 1_analysts/ + │ ├── 2_research/ + │ ├── 3_trader/ + │ ├── 4_risk/ + │ ├── complete_report.md ← ReportStore.save_analysis / load_analysis + │ └── eval/ + │ └── full_states_log.json + │ + ├── daily_digest.md + │ + └── portfolio/ ← NEW: portfolio manager artifacts + ├── {TICKER}_holding_review.json ← ReportStore.save_holding_review + ├── {portfolio_id}_risk_metrics.json + ├── {portfolio_id}_pm_decision.json + └── {portfolio_id}_pm_decision.md (human-readable) +``` + +--- + +## Supabase ↔ Filesystem Link + +The `portfolios.report_path` column stores the **absolute or relative path** to the +daily portfolio subdirectory: + +``` +report_path = "reports/daily/2026-03-20/portfolio" +``` + +This allows the Repository layer to load the PM decision, risk metrics, and holding +reviews by constructing: + +```python +Path(portfolio.report_path) / f"{portfolio_id}_pm_decision.json" +``` + +The path is set by the Repository after the first write on each run day. + +--- + +## Schema Version Notes + +- Migration file: `tradingagents/portfolio/migrations/001_initial_schema.sql` +- All `CREATE TABLE` and `CREATE INDEX` use `IF NOT EXISTS` — safe to re-run +- `CREATE OR REPLACE TRIGGER` / `CREATE OR REPLACE FUNCTION` — idempotent +- Supabase project dashboard: run via SQL Editor or the Supabase CLI + (`supabase db push`) diff --git a/docs/portfolio/04_repository_api.md b/docs/portfolio/04_repository_api.md new file mode 100644 index 00000000..8828af91 --- /dev/null +++ b/docs/portfolio/04_repository_api.md @@ -0,0 +1,492 @@ +# Repository Layer API + + + +This document is the authoritative API reference for all classes in +`tradingagents/portfolio/`. + +--- + +## Exception Hierarchy + +Defined in `tradingagents/portfolio/exceptions.py`. + +``` +PortfolioError # Base exception for all portfolio errors +├── PortfolioNotFoundError # Requested portfolio_id does not exist +├── HoldingNotFoundError # Requested holding (portfolio_id, ticker) does not exist +├── DuplicatePortfolioError # Portfolio name or ID already exists +├── InsufficientCashError # Not enough cash for a BUY trade +├── InsufficientSharesError # Not enough shares for a SELL trade +├── ConstraintViolationError # PM constraint breached (position size, sector, cash) +└── ReportStoreError # Filesystem read/write failure +``` + +### Usage + +```python +from tradingagents.portfolio.exceptions import ( + PortfolioError, + PortfolioNotFoundError, + InsufficientCashError, + InsufficientSharesError, +) + +try: + repo.add_holding(portfolio_id, "AAPL", shares=100, price=195.50) +except InsufficientCashError as e: + print(f"Cannot buy: {e}") +``` + +--- + +## `SupabaseClient` + +Location: `tradingagents/portfolio/supabase_client.py` + +Thin wrapper around the `supabase-py` client that: +- Manages a singleton connection +- Translates HTTP / Supabase errors into domain exceptions +- Converts raw DB rows into model instances + +### Constructor / Singleton + +```python +client = SupabaseClient.get_instance() +# or +client = SupabaseClient(url=SUPABASE_URL, key=SUPABASE_KEY) +``` + +### Portfolio Methods + +```python +def create_portfolio(self, portfolio: Portfolio) -> Portfolio: + """Insert a new portfolio row. + + Raises: + DuplicatePortfolioError: If portfolio_id already exists. + """ + +def get_portfolio(self, portfolio_id: str) -> Portfolio: + """Fetch a portfolio by ID. + + Raises: + PortfolioNotFoundError: If no portfolio has that ID. + """ + +def list_portfolios(self) -> list[Portfolio]: + """Return all portfolios ordered by created_at DESC.""" + +def update_portfolio(self, portfolio: Portfolio) -> Portfolio: + """Update mutable fields (cash, report_path, metadata, updated_at). + + Raises: + PortfolioNotFoundError: If portfolio_id does not exist. + """ + +def delete_portfolio(self, portfolio_id: str) -> None: + """Delete a portfolio and all its holdings, trades, and snapshots (CASCADE). + + Raises: + PortfolioNotFoundError: If portfolio_id does not exist. + """ +``` + +### Holdings Methods + +```python +def upsert_holding(self, holding: Holding) -> Holding: + """Insert or update a holding row (upsert on portfolio_id + ticker). + + Returns the holding with updated DB-assigned fields (updated_at). + """ + +def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None: + """Return the holding for (portfolio_id, ticker), or None if not found.""" + +def list_holdings(self, portfolio_id: str) -> list[Holding]: + """Return all holdings for a portfolio ordered by cost_basis DESC.""" + +def delete_holding(self, portfolio_id: str, ticker: str) -> None: + """Delete the holding for (portfolio_id, ticker). + + Raises: + HoldingNotFoundError: If no such holding exists. + """ +``` + +### Trades Methods + +```python +def record_trade(self, trade: Trade) -> Trade: + """Insert a new trade record. Immutable — no update method. + + Returns the trade with DB-assigned trade_id and trade_date. + """ + +def list_trades( + self, + portfolio_id: str, + ticker: str | None = None, + limit: int = 100, +) -> list[Trade]: + """Return recent trades for a portfolio, newest first. + + Args: + portfolio_id: Filter by portfolio. + ticker: Optional additional filter by ticker symbol. + limit: Maximum number of rows to return. + """ +``` + +### Snapshots Methods + +```python +def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot: + """Insert a new snapshot. Immutable — no update method.""" + +def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None: + """Return the most recent snapshot, or None if none exist.""" + +def list_snapshots( + self, + portfolio_id: str, + limit: int = 30, +) -> list[PortfolioSnapshot]: + """Return snapshots newest-first up to limit.""" +``` + +--- + +## `ReportStore` + +Location: `tradingagents/portfolio/report_store.py` + +Filesystem document store for all non-transactional portfolio artifacts. +Integrates with the existing `tradingagents/report_paths.py` path conventions. + +### Constructor + +```python +store = ReportStore(base_dir: str | Path = "reports") +``` + +`base_dir` defaults to `"reports"` (relative to CWD). Override via +`PORTFOLIO_DATA_DIR` env var or config. + +### Scan Methods + +```python +def save_scan(self, date: str, data: dict) -> Path: + """Save macro scan summary JSON. + + Path: {base_dir}/daily/{date}/market/macro_scan_summary.json + + Returns the path written. + """ + +def load_scan(self, date: str) -> dict | None: + """Load macro scan summary. Returns None if file doesn't exist.""" +``` + +### Analysis Methods + +```python +def save_analysis(self, date: str, ticker: str, data: dict) -> Path: + """Save per-ticker analysis report as JSON. + + Path: {base_dir}/daily/{date}/{TICKER}/complete_report.json + """ + +def load_analysis(self, date: str, ticker: str) -> dict | None: + """Load per-ticker analysis JSON. Returns None if file doesn't exist.""" +``` + +### Holding Review Methods + +```python +def save_holding_review(self, date: str, ticker: str, data: dict) -> Path: + """Save holding reviewer output for one ticker. + + Path: {base_dir}/daily/{date}/portfolio/{TICKER}_holding_review.json + """ + +def load_holding_review(self, date: str, ticker: str) -> dict | None: + """Load holding review. Returns None if file doesn't exist.""" +``` + +### Risk Metrics Methods + +```python +def save_risk_metrics( + self, + date: str, + portfolio_id: str, + data: dict, +) -> Path: + """Save risk computation results. + + Path: {base_dir}/daily/{date}/portfolio/{portfolio_id}_risk_metrics.json + """ + +def load_risk_metrics(self, date: str, portfolio_id: str) -> dict | None: + """Load risk metrics. Returns None if file doesn't exist.""" +``` + +### PM Decision Methods + +```python +def save_pm_decision( + self, + date: str, + portfolio_id: str, + data: dict, + markdown: str | None = None, +) -> Path: + """Save PM agent decision. + + JSON path: {base_dir}/daily/{date}/portfolio/{portfolio_id}_pm_decision.json + MD path: {base_dir}/daily/{date}/portfolio/{portfolio_id}_pm_decision.md + (written only when markdown is not None) + + Returns JSON path. + """ + +def load_pm_decision(self, date: str, portfolio_id: str) -> dict | None: + """Load PM decision JSON. Returns None if file doesn't exist.""" + +def list_pm_decisions(self, portfolio_id: str) -> list[Path]: + """Return all saved PM decision JSON paths for portfolio_id, newest first. + + Scans {base_dir}/daily/*/portfolio/{portfolio_id}_pm_decision.json + """ +``` + +--- + +## `PortfolioRepository` + +Location: `tradingagents/portfolio/repository.py` + +Unified façade that combines `SupabaseClient` and `ReportStore`. +This is the **primary interface** for all portfolio operations — callers should +not interact with `SupabaseClient` or `ReportStore` directly. + +### Constructor + +```python +repo = PortfolioRepository( + client: SupabaseClient | None = None, # uses singleton if None + store: ReportStore | None = None, # uses default if None + config: dict | None = None, # uses get_portfolio_config() if None +) +``` + +### Portfolio Lifecycle + +```python +def create_portfolio( + self, + name: str, + initial_cash: float, + currency: str = "USD", +) -> Portfolio: + """Create a new portfolio with the given starting capital. + + Generates a UUID for portfolio_id. Sets cash = initial_cash. + + Raises: + DuplicatePortfolioError: If name is already in use. + ValueError: If initial_cash <= 0. + """ + +def get_portfolio(self, portfolio_id: str) -> Portfolio: + """Fetch portfolio by ID. + + Raises: + PortfolioNotFoundError: If not found. + """ + +def get_portfolio_with_holdings( + self, + portfolio_id: str, + prices: dict[str, float] | None = None, +) -> tuple[Portfolio, list[Holding]]: + """Fetch portfolio + all holdings, optionally enriched with current prices. + + Args: + portfolio_id: Target portfolio. + prices: Optional dict of {ticker: current_price}. When provided, + holdings are enriched and portfolio.total_value is computed. + + Returns: + (Portfolio, list[Holding]) — Portfolio.enrich() called if prices given. + """ +``` + +### Holdings Management + +```python +def add_holding( + self, + portfolio_id: str, + ticker: str, + shares: float, + price: float, + sector: str | None = None, + industry: str | None = None, +) -> Holding: + """Buy shares and update portfolio cash and holdings. + + Business logic: + - Raises InsufficientCashError if portfolio.cash < shares * price + - If holding already exists: updates avg_cost = weighted average + - portfolio.cash -= shares * price + - Records a BUY trade automatically + + Returns the updated/created Holding. + """ + +def remove_holding( + self, + portfolio_id: str, + ticker: str, + shares: float, + price: float, +) -> Holding | None: + """Sell shares and update portfolio cash and holdings. + + Business logic: + - Raises HoldingNotFoundError if no holding exists for ticker + - Raises InsufficientSharesError if holding.shares < shares + - If shares == holding.shares: deletes the holding row, returns None + - Otherwise: decrements holding.shares (avg_cost unchanged on sell) + - portfolio.cash += shares * price + - Records a SELL trade automatically + + Returns the updated Holding (or None if fully sold). + """ +``` + +### Snapshot Management + +```python +def take_snapshot(self, portfolio_id: str, prices: dict[str, float]) -> PortfolioSnapshot: + """Take an immutable snapshot of the current portfolio state. + + Enriches all holdings with current prices, computes total_value, + then persists to Supabase via SupabaseClient.save_snapshot(). + + Returns the saved PortfolioSnapshot. + """ +``` + +### Report Convenience Methods + +```python +def save_pm_decision( + self, + portfolio_id: str, + date: str, + decision: dict, + markdown: str | None = None, +) -> Path: + """Delegate to ReportStore.save_pm_decision and update portfolio.report_path.""" + +def load_pm_decision(self, portfolio_id: str, date: str) -> dict | None: + """Delegate to ReportStore.load_pm_decision.""" + +def save_risk_metrics( + self, + portfolio_id: str, + date: str, + metrics: dict, +) -> Path: + """Delegate to ReportStore.save_risk_metrics.""" + +def load_risk_metrics(self, portfolio_id: str, date: str) -> dict | None: + """Delegate to ReportStore.load_risk_metrics.""" +``` + +--- + +## Avg Cost Basis Calculation + +When buying more shares of an existing holding, the average cost basis is updated +using the **weighted average** formula: + +``` +new_avg_cost = (old_shares * old_avg_cost + new_shares * new_price) + / (old_shares + new_shares) +``` + +When **selling** shares, the average cost basis is **not changed** — only `shares` +is decremented. This follows the FIFO approximation used by most brokerages for +tax-reporting purposes. + +--- + +## Cash Management Rules + +| Operation | Effect on `portfolio.cash` | +|-----------|---------------------------| +| BUY `n` shares at `p` | `cash -= n * p` | +| SELL `n` shares at `p` | `cash += n * p` | +| Snapshot | Read-only | +| Portfolio creation | `cash = initial_cash` | + +Cash can never go below 0 after a trade. `add_holding` raises +`InsufficientCashError` if the trade would exceed available cash. + +--- + +## Example Usage + +```python +from tradingagents.portfolio import PortfolioRepository + +repo = PortfolioRepository() + +# Create a portfolio +portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0) + +# Buy some shares +holding = repo.add_holding( + portfolio.portfolio_id, + ticker="AAPL", + shares=50, + price=195.50, + sector="Technology", +) +# portfolio.cash is now 100_000 - 50 * 195.50 = 90_225.00 +# holding.avg_cost = 195.50 + +# Buy more (avg cost update) +holding = repo.add_holding( + portfolio.portfolio_id, + ticker="AAPL", + shares=25, + price=200.00, +) +# holding.avg_cost = (50*195.50 + 25*200.00) / 75 = 197.00 + +# Sell half +holding = repo.remove_holding( + portfolio.portfolio_id, + ticker="AAPL", + shares=37, + price=205.00, +) +# portfolio.cash += 37 * 205.00 = 7_585.00 + +# Take snapshot +prices = {"AAPL": 205.00} +snapshot = repo.take_snapshot(portfolio.portfolio_id, prices) + +# Save PM decision +repo.save_pm_decision( + portfolio.portfolio_id, + date="2026-03-20", + decision={"sells": [], "buys": [...], "rationale": "..."}, +) +``` diff --git a/tests/portfolio/__init__.py b/tests/portfolio/__init__.py new file mode 100644 index 00000000..a9b59765 --- /dev/null +++ b/tests/portfolio/__init__.py @@ -0,0 +1 @@ +# tests/portfolio package marker diff --git a/tests/portfolio/conftest.py b/tests/portfolio/conftest.py new file mode 100644 index 00000000..bbbf0d78 --- /dev/null +++ b/tests/portfolio/conftest.py @@ -0,0 +1,110 @@ +"""Shared pytest fixtures for portfolio tests. + +Fixtures provided: + +- ``tmp_reports`` — temporary directory used as ReportStore base_dir +- ``sample_portfolio`` — a Portfolio instance for testing (not persisted) +- ``sample_holding`` — a Holding instance for testing (not persisted) +- ``sample_trade`` — a Trade instance for testing (not persisted) +- ``sample_snapshot`` — a PortfolioSnapshot instance for testing +- ``report_store`` — a ReportStore instance backed by tmp_reports +- ``mock_supabase_client`` — MagicMock of SupabaseClient for unit tests + +Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when +``SUPABASE_URL`` is not set in the environment. +""" + +from __future__ import annotations + +import os +import uuid +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +# --------------------------------------------------------------------------- +# Skip marker for Supabase integration tests +# --------------------------------------------------------------------------- + +requires_supabase = pytest.mark.skipif( + not os.getenv("SUPABASE_URL"), + reason="SUPABASE_URL not set — skipping Supabase integration tests", +) + + +# --------------------------------------------------------------------------- +# Data fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_portfolio_id() -> str: + """Return a fixed UUID for deterministic testing.""" + return "11111111-1111-1111-1111-111111111111" + + +@pytest.fixture +def sample_holding_id() -> str: + """Return a fixed UUID for deterministic testing.""" + return "22222222-2222-2222-2222-222222222222" + + +@pytest.fixture +def sample_portfolio(sample_portfolio_id: str): + """Return an unsaved Portfolio instance for testing.""" + # TODO: implement — construct a Portfolio dataclass with test values + raise NotImplementedError + + +@pytest.fixture +def sample_holding(sample_portfolio_id: str, sample_holding_id: str): + """Return an unsaved Holding instance for testing.""" + # TODO: implement — construct a Holding dataclass with test values + raise NotImplementedError + + +@pytest.fixture +def sample_trade(sample_portfolio_id: str): + """Return an unsaved Trade instance for testing.""" + # TODO: implement — construct a Trade dataclass with test values + raise NotImplementedError + + +@pytest.fixture +def sample_snapshot(sample_portfolio_id: str): + """Return an unsaved PortfolioSnapshot instance for testing.""" + # TODO: implement — construct a PortfolioSnapshot dataclass with test values + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Filesystem fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def tmp_reports(tmp_path: Path) -> Path: + """Temporary reports directory, cleaned up after each test.""" + reports_dir = tmp_path / "reports" + reports_dir.mkdir() + return reports_dir + + +@pytest.fixture +def report_store(tmp_reports: Path): + """ReportStore instance backed by a temporary directory.""" + # TODO: implement — return ReportStore(base_dir=tmp_reports) + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Mock Supabase client fixture +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_supabase_client(): + """MagicMock of SupabaseClient for unit tests that don't hit the DB.""" + # TODO: implement — return MagicMock(spec=SupabaseClient) + raise NotImplementedError diff --git a/tests/portfolio/test_models.py b/tests/portfolio/test_models.py new file mode 100644 index 00000000..72f3c798 --- /dev/null +++ b/tests/portfolio/test_models.py @@ -0,0 +1,132 @@ +"""Tests for tradingagents/portfolio/models.py. + +Tests the four dataclass models: Portfolio, Holding, Trade, PortfolioSnapshot. + +Coverage targets: +- to_dict() / from_dict() round-trips +- enrich() computed-field logic +- Edge cases (zero cost basis, zero portfolio value) + +Run:: + + pytest tests/portfolio/test_models.py -v +""" + +from __future__ import annotations + +import pytest + + +# --------------------------------------------------------------------------- +# Portfolio round-trip +# --------------------------------------------------------------------------- + + +def test_portfolio_to_dict_round_trip(sample_portfolio): + """Portfolio.to_dict() -> Portfolio.from_dict() must be lossless.""" + # TODO: implement + # d = sample_portfolio.to_dict() + # restored = Portfolio.from_dict(d) + # assert restored.portfolio_id == sample_portfolio.portfolio_id + # assert restored.cash == sample_portfolio.cash + # ... all stored fields + raise NotImplementedError + + +def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): + """to_dict() must not include computed fields (total_value, equity_value, cash_pct).""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Holding round-trip +# --------------------------------------------------------------------------- + + +def test_holding_to_dict_round_trip(sample_holding): + """Holding.to_dict() -> Holding.from_dict() must be lossless.""" + # TODO: implement + raise NotImplementedError + + +def test_holding_to_dict_excludes_runtime_fields(sample_holding): + """to_dict() must not include current_price, current_value, weight, etc.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Trade round-trip +# --------------------------------------------------------------------------- + + +def test_trade_to_dict_round_trip(sample_trade): + """Trade.to_dict() -> Trade.from_dict() must be lossless.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# PortfolioSnapshot round-trip +# --------------------------------------------------------------------------- + + +def test_snapshot_to_dict_round_trip(sample_snapshot): + """PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Holding.enrich() +# --------------------------------------------------------------------------- + + +def test_holding_enrich_computes_current_value(sample_holding): + """enrich() must set current_value = current_price * shares.""" + # TODO: implement + # sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + # assert sample_holding.current_value == 200.0 * sample_holding.shares + raise NotImplementedError + + +def test_holding_enrich_computes_unrealized_pnl(sample_holding): + """enrich() must set unrealized_pnl = current_value - cost_basis.""" + # TODO: implement + raise NotImplementedError + + +def test_holding_enrich_computes_weight(sample_holding): + """enrich() must set weight = current_value / portfolio_total_value.""" + # TODO: implement + raise NotImplementedError + + +def test_holding_enrich_handles_zero_cost(sample_holding): + """When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError).""" + # TODO: implement + raise NotImplementedError + + +def test_holding_enrich_handles_zero_portfolio_value(sample_holding): + """When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError).""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Portfolio.enrich() +# --------------------------------------------------------------------------- + + +def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding): + """Portfolio.enrich() must compute total_value = cash + sum(holding.current_value).""" + # TODO: implement + raise NotImplementedError + + +def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding): + """Portfolio.enrich() must compute cash_pct = cash / total_value.""" + # TODO: implement + raise NotImplementedError diff --git a/tests/portfolio/test_report_store.py b/tests/portfolio/test_report_store.py new file mode 100644 index 00000000..ba111799 --- /dev/null +++ b/tests/portfolio/test_report_store.py @@ -0,0 +1,145 @@ +"""Tests for tradingagents/portfolio/report_store.py. + +Tests filesystem save/load operations for all report types. + +All tests use a temporary directory (``tmp_reports`` fixture) and do not +require Supabase or network access. + +Run:: + + pytest tests/portfolio/test_report_store.py -v +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + + +# --------------------------------------------------------------------------- +# Macro scan +# --------------------------------------------------------------------------- + + +def test_save_and_load_scan(report_store, tmp_reports): + """save_scan() then load_scan() must return the original data.""" + # TODO: implement + # data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} + # path = report_store.save_scan("2026-03-20", data) + # assert path.exists() + # loaded = report_store.load_scan("2026-03-20") + # assert loaded == data + raise NotImplementedError + + +def test_load_scan_returns_none_for_missing_file(report_store): + """load_scan() must return None when the file does not exist.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Per-ticker analysis +# --------------------------------------------------------------------------- + + +def test_save_and_load_analysis(report_store): + """save_analysis() then load_analysis() must return the original data.""" + # TODO: implement + raise NotImplementedError + + +def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): + """Ticker symbol must be stored as uppercase in the directory path.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Holding reviews +# --------------------------------------------------------------------------- + + +def test_save_and_load_holding_review(report_store): + """save_holding_review() then load_holding_review() must round-trip.""" + # TODO: implement + raise NotImplementedError + + +def test_load_holding_review_returns_none_for_missing(report_store): + """load_holding_review() must return None when the file does not exist.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Risk metrics +# --------------------------------------------------------------------------- + + +def test_save_and_load_risk_metrics(report_store): + """save_risk_metrics() then load_risk_metrics() must round-trip.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# PM decisions +# --------------------------------------------------------------------------- + + +def test_save_and_load_pm_decision_json(report_store): + """save_pm_decision() then load_pm_decision() must round-trip JSON.""" + # TODO: implement + # decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} + # report_store.save_pm_decision("2026-03-20", "pid-123", decision) + # loaded = report_store.load_pm_decision("2026-03-20", "pid-123") + # assert loaded == decision + raise NotImplementedError + + +def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports): + """When markdown is passed to save_pm_decision(), .md file must be written.""" + # TODO: implement + raise NotImplementedError + + +def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports): + """When markdown=None, no .md file should be written.""" + # TODO: implement + raise NotImplementedError + + +def test_load_pm_decision_returns_none_for_missing(report_store): + """load_pm_decision() must return None when the file does not exist.""" + # TODO: implement + raise NotImplementedError + + +def test_list_pm_decisions(report_store): + """list_pm_decisions() must return all saved decision paths, newest first.""" + # TODO: implement + # Save decisions for multiple dates, verify order + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Filesystem behaviour +# --------------------------------------------------------------------------- + + +def test_directories_created_on_write(report_store, tmp_reports): + """Directories must be created automatically on first write.""" + # TODO: implement + # assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists() + # report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) + # assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir() + raise NotImplementedError + + +def test_json_formatted_with_indent(report_store, tmp_reports): + """Written JSON files must use indent=2 for human readability.""" + # TODO: implement + # Write a file, read the raw bytes, verify indentation + raise NotImplementedError diff --git a/tests/portfolio/test_repository.py b/tests/portfolio/test_repository.py new file mode 100644 index 00000000..3658ab26 --- /dev/null +++ b/tests/portfolio/test_repository.py @@ -0,0 +1,198 @@ +"""Tests for tradingagents/portfolio/repository.py. + +Tests the PortfolioRepository façade — business logic for holdings management, +cash accounting, avg-cost-basis updates, and snapshot creation. + +Supabase integration tests are automatically skipped when ``SUPABASE_URL`` is +not set in the environment (use the ``requires_supabase`` fixture marker). + +Unit tests use ``mock_supabase_client`` to avoid DB access. + +Run (unit tests only):: + + pytest tests/portfolio/test_repository.py -v -k "not integration" + +Run (with Supabase):: + + SUPABASE_URL=... SUPABASE_KEY=... pytest tests/portfolio/test_repository.py -v +""" + +from __future__ import annotations + +import pytest + +from tests.portfolio.conftest import requires_supabase + + +# --------------------------------------------------------------------------- +# add_holding — new position +# --------------------------------------------------------------------------- + + +def test_add_holding_new_position(mock_supabase_client, report_store): + """add_holding() on a ticker not yet held must create a new Holding.""" + # TODO: implement + # repo = PortfolioRepository(client=mock_supabase_client, store=report_store) + # mock portfolio with enough cash + # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) + # assert mock_supabase_client.upsert_holding.called + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# add_holding — avg cost basis update +# --------------------------------------------------------------------------- + + +def test_add_holding_updates_avg_cost(mock_supabase_client, report_store): + """add_holding() on an existing position must update avg_cost correctly. + + Formula: new_avg_cost = (old_shares * old_avg_cost + new_shares * price) + / (old_shares + new_shares) + """ + # TODO: implement + # existing holding: 50 shares @ 190.0 + # buy 25 more @ 200.0 + # expected avg_cost = (50*190 + 25*200) / 75 = 193.33... + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# add_holding — insufficient cash +# --------------------------------------------------------------------------- + + +def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store): + """add_holding() must raise InsufficientCashError when cash < shares * price.""" + # TODO: implement + # portfolio with cash=500.0, try to buy 10 shares @ 200.0 (cost=2000) + # with pytest.raises(InsufficientCashError): + # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# remove_holding — full position +# --------------------------------------------------------------------------- + + +def test_remove_holding_full_position(mock_supabase_client, report_store): + """remove_holding() selling all shares must delete the holding row.""" + # TODO: implement + # holding: 50 shares + # sell 50 shares → holding deleted, cash credited + # assert mock_supabase_client.delete_holding.called + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# remove_holding — partial position +# --------------------------------------------------------------------------- + + +def test_remove_holding_partial_position(mock_supabase_client, report_store): + """remove_holding() selling a subset must reduce shares, not delete.""" + # TODO: implement + # holding: 50 shares + # sell 20 → holding.shares == 30, avg_cost unchanged + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# remove_holding — errors +# --------------------------------------------------------------------------- + + +def test_remove_holding_raises_insufficient_shares(mock_supabase_client, report_store): + """remove_holding() must raise InsufficientSharesError when shares > held.""" + # TODO: implement + # holding: 10 shares + # try sell 20 → InsufficientSharesError + raise NotImplementedError + + +def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report_store): + """remove_holding() must raise HoldingNotFoundError for unknown tickers.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Cash accounting +# --------------------------------------------------------------------------- + + +def test_add_holding_deducts_cash(mock_supabase_client, report_store): + """add_holding() must reduce portfolio.cash by shares * price.""" + # TODO: implement + # portfolio.cash = 10_000, buy 10 @ 200 → cash should be 8_000 + raise NotImplementedError + + +def test_remove_holding_credits_cash(mock_supabase_client, report_store): + """remove_holding() must increase portfolio.cash by shares * price.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Trade recording +# --------------------------------------------------------------------------- + + +def test_add_holding_records_buy_trade(mock_supabase_client, report_store): + """add_holding() must call client.record_trade() with action='BUY'.""" + # TODO: implement + raise NotImplementedError + + +def test_remove_holding_records_sell_trade(mock_supabase_client, report_store): + """remove_holding() must call client.record_trade() with action='SELL'.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Snapshot +# --------------------------------------------------------------------------- + + +def test_take_snapshot(mock_supabase_client, report_store): + """take_snapshot() must enrich holdings and persist a PortfolioSnapshot.""" + # TODO: implement + # assert mock_supabase_client.save_snapshot.called + # snapshot.total_value == cash + equity + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Supabase integration tests (auto-skip without SUPABASE_URL) +# --------------------------------------------------------------------------- + + +@requires_supabase +def test_integration_create_and_get_portfolio(): + """Integration: create a portfolio, retrieve it, verify fields match.""" + # TODO: implement + raise NotImplementedError + + +@requires_supabase +def test_integration_add_and_remove_holding(): + """Integration: add holding, verify DB row; remove, verify deletion.""" + # TODO: implement + raise NotImplementedError + + +@requires_supabase +def test_integration_record_and_list_trades(): + """Integration: record BUY + SELL trades, list them, verify order.""" + # TODO: implement + raise NotImplementedError + + +@requires_supabase +def test_integration_save_and_load_snapshot(): + """Integration: take snapshot, retrieve latest, verify total_value.""" + # TODO: implement + raise NotImplementedError diff --git a/tradingagents/portfolio/__init__.py b/tradingagents/portfolio/__init__.py new file mode 100644 index 00000000..6bddef33 --- /dev/null +++ b/tradingagents/portfolio/__init__.py @@ -0,0 +1,55 @@ +"""Portfolio Manager — public package exports. + +Import the primary interface classes from this package: + + from tradingagents.portfolio import ( + PortfolioRepository, + Portfolio, + Holding, + Trade, + PortfolioSnapshot, + PortfolioError, + PortfolioNotFoundError, + InsufficientCashError, + InsufficientSharesError, + ) +""" + +from __future__ import annotations + +from tradingagents.portfolio.exceptions import ( + PortfolioError, + PortfolioNotFoundError, + HoldingNotFoundError, + DuplicatePortfolioError, + InsufficientCashError, + InsufficientSharesError, + ConstraintViolationError, + ReportStoreError, +) +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) +from tradingagents.portfolio.repository import PortfolioRepository + +__all__ = [ + # Models + "Portfolio", + "Holding", + "Trade", + "PortfolioSnapshot", + # Repository (primary interface) + "PortfolioRepository", + # Exceptions + "PortfolioError", + "PortfolioNotFoundError", + "HoldingNotFoundError", + "DuplicatePortfolioError", + "InsufficientCashError", + "InsufficientSharesError", + "ConstraintViolationError", + "ReportStoreError", +] diff --git a/tradingagents/portfolio/config.py b/tradingagents/portfolio/config.py new file mode 100644 index 00000000..4ffbc403 --- /dev/null +++ b/tradingagents/portfolio/config.py @@ -0,0 +1,87 @@ +"""Portfolio Manager configuration. + +Reads all portfolio-related settings from environment variables with sensible +defaults. Integrates with the existing ``tradingagents/default_config.py`` +pattern. + +Usage:: + + from tradingagents.portfolio.config import get_portfolio_config, validate_config + + cfg = get_portfolio_config() + validate_config(cfg) + print(cfg["max_positions"]) # 15 + +Environment variables (all optional): + + SUPABASE_URL Supabase project URL (default: "") + SUPABASE_KEY Supabase anon/service role key (default: "") + PORTFOLIO_DATA_DIR Root dir for filesystem reports (default: "reports") + PM_MAX_POSITIONS Max open positions (default: 15) + PM_MAX_POSITION_PCT Max single-position weight 0–1 (default: 0.15) + PM_MAX_SECTOR_PCT Max sector weight 0–1 (default: 0.35) + PM_MIN_CASH_PCT Minimum cash reserve 0–1 (default: 0.05) + PM_DEFAULT_BUDGET Default starting cash in USD (default: 100000.0) +""" + +from __future__ import annotations + +import os + +# --------------------------------------------------------------------------- +# Defaults +# --------------------------------------------------------------------------- + +PORTFOLIO_CONFIG: dict = { + # Supabase connection + "supabase_url": "", + "supabase_key": "", + # Filesystem report root (matches report_paths.py REPORTS_ROOT) + "data_dir": "reports", + # PM constraint defaults + "max_positions": 15, + "max_position_pct": 0.15, + "max_sector_pct": 0.35, + "min_cash_pct": 0.05, + "default_budget": 100_000.0, +} + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def get_portfolio_config() -> dict: + """Return the merged portfolio config (defaults overridden by env vars). + + Reads ``SUPABASE_URL``, ``SUPABASE_KEY``, ``PORTFOLIO_DATA_DIR``, + ``PM_MAX_POSITIONS``, ``PM_MAX_POSITION_PCT``, ``PM_MAX_SECTOR_PCT``, + ``PM_MIN_CASH_PCT``, and ``PM_DEFAULT_BUDGET`` from the environment. + + Returns: + A dict with all portfolio configuration keys. + """ + # TODO: implement — merge PORTFOLIO_CONFIG with env var overrides + raise NotImplementedError + + +def validate_config(cfg: dict) -> None: + """Validate a portfolio config dict, raising ValueError on invalid values. + + Checks: + - ``max_positions >= 1`` + - ``0 < max_position_pct <= 1`` + - ``0 < max_sector_pct <= 1`` + - ``0 <= min_cash_pct < 1`` + - ``default_budget > 0`` + - ``min_cash_pct + max_position_pct <= 1`` (can always meet both constraints) + + Args: + cfg: Config dict as returned by ``get_portfolio_config()``. + + Raises: + ValueError: With a descriptive message on the first failed check. + """ + # TODO: implement + raise NotImplementedError diff --git a/tradingagents/portfolio/exceptions.py b/tradingagents/portfolio/exceptions.py new file mode 100644 index 00000000..ac0cfc87 --- /dev/null +++ b/tradingagents/portfolio/exceptions.py @@ -0,0 +1,60 @@ +"""Domain exception hierarchy for the portfolio management package. + +All exceptions raised by this package inherit from ``PortfolioError`` so that +callers can catch the entire family with a single ``except PortfolioError``. + +Example:: + + from tradingagents.portfolio.exceptions import ( + PortfolioError, + InsufficientCashError, + ) + + try: + repo.add_holding(pid, "AAPL", shares=100, price=195.50) + except InsufficientCashError as e: + print(f"Cannot buy: {e}") + except PortfolioError as e: + print(f"Unexpected portfolio error: {e}") +""" + +from __future__ import annotations + + +class PortfolioError(Exception): + """Base exception for all portfolio-management errors.""" + + +class PortfolioNotFoundError(PortfolioError): + """Raised when a requested portfolio_id does not exist in the database.""" + + +class HoldingNotFoundError(PortfolioError): + """Raised when a requested (portfolio_id, ticker) holding does not exist.""" + + +class DuplicatePortfolioError(PortfolioError): + """Raised when attempting to create a portfolio that already exists.""" + + +class InsufficientCashError(PortfolioError): + """Raised when a BUY order exceeds the portfolio's available cash balance.""" + + +class InsufficientSharesError(PortfolioError): + """Raised when a SELL order exceeds the number of shares held.""" + + +class ConstraintViolationError(PortfolioError): + """Raised when a trade would violate a PM constraint. + + Constraints enforced: + - Max position size (default 15 % of portfolio value) + - Max sector exposure (default 35 % of portfolio value) + - Min cash reserve (default 5 % of portfolio value) + - Max number of positions (default 15) + """ + + +class ReportStoreError(PortfolioError): + """Raised on filesystem read/write failures in ReportStore.""" diff --git a/tradingagents/portfolio/migrations/001_initial_schema.sql b/tradingagents/portfolio/migrations/001_initial_schema.sql new file mode 100644 index 00000000..42724b7a --- /dev/null +++ b/tradingagents/portfolio/migrations/001_initial_schema.sql @@ -0,0 +1,161 @@ +-- ============================================================================= +-- Portfolio Manager Agent — Initial Schema +-- Migration: 001_initial_schema.sql +-- Description: Creates all tables, indexes, and triggers for the portfolio +-- management data layer. +-- Safe to re-run: all statements use IF NOT EXISTS / CREATE OR REPLACE. +-- ============================================================================= + + +-- --------------------------------------------------------------------------- +-- Table: portfolios +-- Purpose: One row per managed portfolio. Tracks cash balance, initial capital, +-- and a pointer to the filesystem report directory. +-- --------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS portfolios ( + portfolio_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + name TEXT NOT NULL, + cash NUMERIC(18,4) NOT NULL CHECK (cash >= 0), + initial_cash NUMERIC(18,4) NOT NULL CHECK (initial_cash > 0), + currency CHAR(3) NOT NULL DEFAULT 'USD', + report_path TEXT, -- relative FS path to daily report dir + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +COMMENT ON TABLE portfolios IS + 'One row per managed portfolio. Tracks cash balance and links to filesystem reports.'; +COMMENT ON COLUMN portfolios.report_path IS + 'Relative path to the daily portfolio report directory, e.g. reports/daily/2026-03-20/portfolio'; +COMMENT ON COLUMN portfolios.metadata IS + 'Free-form JSONB for agent notes, tags, or strategy parameters.'; + + +-- --------------------------------------------------------------------------- +-- Table: holdings +-- Purpose: Current open positions. One row per (portfolio, ticker). Deleted +-- when shares reach zero — zero-share rows are never stored. +-- --------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS holdings ( + holding_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + ticker TEXT NOT NULL, + shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), + avg_cost NUMERIC(18,4) NOT NULL CHECK (avg_cost >= 0), + sector TEXT, + industry TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + + CONSTRAINT holdings_portfolio_ticker_unique UNIQUE (portfolio_id, ticker) +); + +COMMENT ON TABLE holdings IS + 'Open positions. Upserted on BUY (avg-cost update), deleted when fully sold.'; +COMMENT ON COLUMN holdings.avg_cost IS + 'Weighted-average cost basis per share in portfolio currency.'; + + +-- --------------------------------------------------------------------------- +-- Table: trades +-- Purpose: Immutable append-only log of every mock trade. Never modified. +-- --------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS trades ( + trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + ticker TEXT NOT NULL, + action TEXT NOT NULL, + shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), + price NUMERIC(18,4) NOT NULL CHECK (price > 0), + total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0), + trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), + rationale TEXT, -- PM agent rationale for this trade + signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent' + metadata JSONB NOT NULL DEFAULT '{}', + + CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL')) +); + +COMMENT ON TABLE trades IS + 'Immutable trade log. Records every mock BUY/SELL with PM rationale.'; +COMMENT ON COLUMN trades.rationale IS + 'Natural-language reason provided by the Portfolio Manager Agent.'; +COMMENT ON COLUMN trades.signal_source IS + 'Which sub-system generated the trade signal: scanner, holding_review, or pm_agent.'; + + +-- --------------------------------------------------------------------------- +-- Table: snapshots +-- Purpose: Immutable point-in-time portfolio state. Taken after each trade +-- execution session for performance tracking and time-series analysis. +-- --------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS snapshots ( + snapshot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, + snapshot_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), + total_value NUMERIC(18,4) NOT NULL, + cash NUMERIC(18,4) NOT NULL, + equity_value NUMERIC(18,4) NOT NULL, + num_positions INTEGER NOT NULL CHECK (num_positions >= 0), + holdings_snapshot JSONB NOT NULL DEFAULT '[]', -- serialised List[Holding.to_dict()] + metadata JSONB NOT NULL DEFAULT '{}' +); + +COMMENT ON TABLE snapshots IS + 'Immutable portfolio snapshots for performance tracking (NAV series).'; +COMMENT ON COLUMN snapshots.holdings_snapshot IS + 'JSONB array of Holding.to_dict() objects at snapshot time.'; + + +-- --------------------------------------------------------------------------- +-- Indexes +-- --------------------------------------------------------------------------- + +-- portfolios: lookup by name (uniqueness enforced at application level) +CREATE INDEX IF NOT EXISTS idx_portfolios_name + ON portfolios (name); + +-- holdings: list all holdings for a portfolio (most frequent query) +CREATE INDEX IF NOT EXISTS idx_holdings_portfolio_id + ON holdings (portfolio_id); + +-- holdings: fast (portfolio, ticker) point lookup for upserts +CREATE INDEX IF NOT EXISTS idx_holdings_portfolio_ticker + ON holdings (portfolio_id, ticker); + +-- trades: list recent trades for a portfolio, newest first +CREATE INDEX IF NOT EXISTS idx_trades_portfolio_id_date + ON trades (portfolio_id, trade_date DESC); + +-- trades: filter by ticker within a portfolio +CREATE INDEX IF NOT EXISTS idx_trades_portfolio_ticker + ON trades (portfolio_id, ticker); + +-- snapshots: get latest snapshot for a portfolio +CREATE INDEX IF NOT EXISTS idx_snapshots_portfolio_id_date + ON snapshots (portfolio_id, snapshot_date DESC); + + +-- --------------------------------------------------------------------------- +-- updated_at trigger +-- Purpose: Automatically sets updated_at = NOW() on every UPDATE for mutable +-- tables (portfolios, holdings). Trades and snapshots are immutable. +-- --------------------------------------------------------------------------- +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Apply to portfolios +CREATE OR REPLACE TRIGGER trg_portfolios_updated_at + BEFORE UPDATE ON portfolios + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); + +-- Apply to holdings +CREATE OR REPLACE TRIGGER trg_holdings_updated_at + BEFORE UPDATE ON holdings + FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); diff --git a/tradingagents/portfolio/models.py b/tradingagents/portfolio/models.py new file mode 100644 index 00000000..9a61fe9b --- /dev/null +++ b/tradingagents/portfolio/models.py @@ -0,0 +1,207 @@ +"""Data models for the Portfolio Manager Agent. + +All models are Python ``dataclass`` types with: +- Full type annotations +- ``to_dict()`` for serialisation (JSON / Supabase) +- ``from_dict()`` class method for deserialisation +- ``enrich()`` for attaching runtime-computed fields + +See ``docs/portfolio/02_data_models.md`` for full field specifications. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +# --------------------------------------------------------------------------- +# Portfolio +# --------------------------------------------------------------------------- + + +@dataclass +class Portfolio: + """A managed investment portfolio. + + Stored fields are persisted to Supabase. Computed fields (total_value, + equity_value, cash_pct) are populated by ``enrich()`` and are *not* + persisted. + """ + + portfolio_id: str + name: str + cash: float + initial_cash: float + currency: str = "USD" + created_at: str = "" + updated_at: str = "" + report_path: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + # Runtime-computed (not stored in DB) + total_value: float | None = field(default=None, repr=False) + equity_value: float | None = field(default=None, repr=False) + cash_pct: float | None = field(default=None, repr=False) + + def to_dict(self) -> dict[str, Any]: + """Serialise stored fields to a flat dict for JSON / Supabase. + + Runtime-computed fields are excluded. + """ + # TODO: implement + raise NotImplementedError + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Portfolio": + """Deserialise from a DB row or JSON dict. + + Missing optional fields default gracefully. Extra keys are ignored. + """ + # TODO: implement + raise NotImplementedError + + def enrich(self, holdings: list["Holding"]) -> "Portfolio": + """Compute total_value, equity_value, cash_pct from holdings. + + Modifies self in-place and returns self for chaining. + + Args: + holdings: List of Holding objects with current_value populated + (i.e., ``holding.enrich()`` already called). + """ + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Holding +# --------------------------------------------------------------------------- + + +@dataclass +class Holding: + """An open position within a portfolio. + + Stored fields are persisted to Supabase. Runtime-computed fields + (current_price, current_value, etc.) are populated by ``enrich()``. + """ + + holding_id: str + portfolio_id: str + ticker: str + shares: float + avg_cost: float + sector: str | None = None + industry: str | None = None + created_at: str = "" + updated_at: str = "" + + # Runtime-computed (not stored in DB) + current_price: float | None = field(default=None, repr=False) + current_value: float | None = field(default=None, repr=False) + cost_basis: float | None = field(default=None, repr=False) + unrealized_pnl: float | None = field(default=None, repr=False) + unrealized_pnl_pct: float | None = field(default=None, repr=False) + weight: float | None = field(default=None, repr=False) + + def to_dict(self) -> dict[str, Any]: + """Serialise stored fields only (runtime-computed fields excluded).""" + # TODO: implement + raise NotImplementedError + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Holding": + """Deserialise from a DB row or JSON dict.""" + # TODO: implement + raise NotImplementedError + + def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding": + """Populate runtime-computed fields in-place and return self. + + Formula: + current_value = current_price * shares + cost_basis = avg_cost * shares + unrealized_pnl = current_value - cost_basis + unrealized_pnl_pct = unrealized_pnl / cost_basis (0 when cost_basis == 0) + weight = current_value / portfolio_total_value (0 when total == 0) + + Args: + current_price: Latest market price for this ticker. + portfolio_total_value: Total portfolio value (cash + equity). + """ + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Trade +# --------------------------------------------------------------------------- + + +@dataclass +class Trade: + """An immutable record of a single mock trade execution. + + Trades are never modified after creation. + """ + + trade_id: str + portfolio_id: str + ticker: str + action: str # "BUY" or "SELL" + shares: float + price: float + total_value: float + trade_date: str = "" + rationale: str | None = None + signal_source: str | None = None # "scanner" | "holding_review" | "pm_agent" + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Serialise all fields.""" + # TODO: implement + raise NotImplementedError + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Trade": + """Deserialise from a DB row or JSON dict.""" + # TODO: implement + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# PortfolioSnapshot +# --------------------------------------------------------------------------- + + +@dataclass +class PortfolioSnapshot: + """An immutable point-in-time snapshot of portfolio state. + + Taken after every trade execution session (Phase 5 of the PM workflow). + Used for NAV time-series, performance attribution, and risk backtesting. + """ + + snapshot_id: str + portfolio_id: str + snapshot_date: str + total_value: float + cash: float + equity_value: float + num_positions: int + holdings_snapshot: list[dict[str, Any]] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Serialise all fields. ``holdings_snapshot`` is already a list[dict].""" + # TODO: implement + raise NotImplementedError + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot": + """Deserialise from DB row or JSON dict. + + ``holdings_snapshot`` is parsed from a JSON string when needed. + """ + # TODO: implement + raise NotImplementedError diff --git a/tradingagents/portfolio/report_store.py b/tradingagents/portfolio/report_store.py new file mode 100644 index 00000000..5ed8594a --- /dev/null +++ b/tradingagents/portfolio/report_store.py @@ -0,0 +1,245 @@ +"""Filesystem document store for Portfolio Manager reports. + +Saves and loads all non-transactional portfolio artifacts (scans, per-ticker +analysis, holding reviews, risk metrics, PM decisions) using the existing +``tradingagents/report_paths.py`` path convention. + +Directory layout:: + + reports/daily/{date}/ + ├── market/ + │ └── macro_scan_summary.json ← save_scan / load_scan + ├── {TICKER}/ + │ └── complete_report.json ← save_analysis / load_analysis + └── portfolio/ + ├── {TICKER}_holding_review.json ← save/load_holding_review + ├── {portfolio_id}_risk_metrics.json + ├── {portfolio_id}_pm_decision.json + └── {portfolio_id}_pm_decision.md + +Usage:: + + from tradingagents.portfolio.report_store import ReportStore + + store = ReportStore() + store.save_scan("2026-03-20", {"watchlist": [...]}) + data = store.load_scan("2026-03-20") +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + + +class ReportStore: + """Filesystem document store for all portfolio-related reports. + + Directories are created automatically on first write. + All load methods return ``None`` when the file does not exist. + """ + + def __init__(self, base_dir: str | Path = "reports") -> None: + """Initialise the store with a base reports directory. + + Args: + base_dir: Root directory for all reports. Defaults to ``"reports"`` + (relative to CWD), matching ``report_paths.REPORTS_ROOT``. + Override via the ``PORTFOLIO_DATA_DIR`` env var or + ``get_portfolio_config()["data_dir"]``. + """ + # TODO: implement — store Path(base_dir), resolve as needed + raise NotImplementedError + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _portfolio_dir(self, date: str) -> Path: + """Return the portfolio subdirectory for a given date. + + Path: ``{base_dir}/daily/{date}/portfolio/`` + """ + # TODO: implement + raise NotImplementedError + + def _write_json(self, path: Path, data: dict[str, Any]) -> Path: + """Write a dict to a JSON file, creating parent directories as needed. + + Args: + path: Target file path. + data: Data to serialise. + + Returns: + The path written. + + Raises: + ReportStoreError: On filesystem write failure. + """ + # TODO: implement + raise NotImplementedError + + def _read_json(self, path: Path) -> dict[str, Any] | None: + """Read a JSON file, returning None if the file does not exist. + + Raises: + ReportStoreError: On JSON parse error (file exists but is corrupt). + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Macro Scan + # ------------------------------------------------------------------ + + def save_scan(self, date: str, data: dict[str, Any]) -> Path: + """Save macro scan summary JSON. + + Path: ``{base_dir}/daily/{date}/market/macro_scan_summary.json`` + + Args: + date: ISO date string, e.g. ``"2026-03-20"``. + data: Scan output dict (typically the macro_scan_summary). + + Returns: + Path of the written file. + """ + # TODO: implement + raise NotImplementedError + + def load_scan(self, date: str) -> dict[str, Any] | None: + """Load macro scan summary. Returns None if the file does not exist.""" + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Per-Ticker Analysis + # ------------------------------------------------------------------ + + def save_analysis(self, date: str, ticker: str, data: dict[str, Any]) -> Path: + """Save per-ticker analysis report as JSON. + + Path: ``{base_dir}/daily/{date}/{TICKER}/complete_report.json`` + + Args: + date: ISO date string. + ticker: Ticker symbol (stored as uppercase). + data: Analysis output dict. + """ + # TODO: implement + raise NotImplementedError + + def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None: + """Load per-ticker analysis JSON. Returns None if the file does not exist.""" + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Holding Reviews + # ------------------------------------------------------------------ + + def save_holding_review( + self, + date: str, + ticker: str, + data: dict[str, Any], + ) -> Path: + """Save holding reviewer output for one ticker. + + Path: ``{base_dir}/daily/{date}/portfolio/{TICKER}_holding_review.json`` + + Args: + date: ISO date string. + ticker: Ticker symbol (stored as uppercase). + data: HoldingReviewerAgent output dict. + """ + # TODO: implement + raise NotImplementedError + + def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None: + """Load holding review output. Returns None if the file does not exist.""" + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Risk Metrics + # ------------------------------------------------------------------ + + def save_risk_metrics( + self, + date: str, + portfolio_id: str, + data: dict[str, Any], + ) -> Path: + """Save risk computation results. + + Path: ``{base_dir}/daily/{date}/portfolio/{portfolio_id}_risk_metrics.json`` + + Args: + date: ISO date string. + portfolio_id: UUID of the target portfolio. + data: Risk metrics dict (Sharpe, Sortino, VaR, etc.). + """ + # TODO: implement + raise NotImplementedError + + def load_risk_metrics( + self, + date: str, + portfolio_id: str, + ) -> dict[str, Any] | None: + """Load risk metrics. Returns None if the file does not exist.""" + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # PM Decisions + # ------------------------------------------------------------------ + + def save_pm_decision( + self, + date: str, + portfolio_id: str, + data: dict[str, Any], + markdown: str | None = None, + ) -> Path: + """Save PM agent decision. + + JSON path: ``{base_dir}/daily/{date}/portfolio/{portfolio_id}_pm_decision.json`` + MD path: ``{base_dir}/daily/{date}/portfolio/{portfolio_id}_pm_decision.md`` + (written only when ``markdown`` is not None) + + Args: + date: ISO date string. + portfolio_id: UUID of the target portfolio. + data: PM decision dict (sells, buys, holds, rationale, …). + markdown: Optional human-readable version; written when provided. + + Returns: + Path of the written JSON file. + """ + # TODO: implement + raise NotImplementedError + + def load_pm_decision( + self, + date: str, + portfolio_id: str, + ) -> dict[str, Any] | None: + """Load PM decision JSON. Returns None if the file does not exist.""" + # TODO: implement + raise NotImplementedError + + def list_pm_decisions(self, portfolio_id: str) -> list[Path]: + """Return all saved PM decision JSON paths for portfolio_id, newest first. + + Scans ``{base_dir}/daily/*/portfolio/{portfolio_id}_pm_decision.json``. + + Args: + portfolio_id: UUID of the target portfolio. + + Returns: + Sorted list of Path objects, newest date first. + """ + # TODO: implement + raise NotImplementedError diff --git a/tradingagents/portfolio/repository.py b/tradingagents/portfolio/repository.py new file mode 100644 index 00000000..123803fd --- /dev/null +++ b/tradingagents/portfolio/repository.py @@ -0,0 +1,315 @@ +"""Unified data-access façade for the Portfolio Manager. + +``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and +``ReportStore`` (filesystem documents) into a single, business-logic-aware +interface. + +Callers should **only** interact with ``PortfolioRepository`` — do not use +``SupabaseClient`` or ``ReportStore`` directly from outside this package. + +Usage:: + + from tradingagents.portfolio import PortfolioRepository + + repo = PortfolioRepository() + + # Create a portfolio + portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0) + + # Buy shares + holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50) + + # Sell shares + repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=25, price=200.00) + + # Snapshot + snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 200.00}) + +See ``docs/portfolio/04_repository_api.md`` for full API documentation. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from tradingagents.portfolio.exceptions import ( + HoldingNotFoundError, + InsufficientCashError, + InsufficientSharesError, +) +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) +from tradingagents.portfolio.report_store import ReportStore +from tradingagents.portfolio.supabase_client import SupabaseClient + + +class PortfolioRepository: + """Unified façade over SupabaseClient and ReportStore. + + Implements business logic for: + - Average cost basis updates on repeated buys + - Cash deduction / credit on trades + - Constraint enforcement (cash, position size) + - Snapshot management + """ + + def __init__( + self, + client: SupabaseClient | None = None, + store: ReportStore | None = None, + config: dict[str, Any] | None = None, + ) -> None: + """Initialise the repository. + + Args: + client: SupabaseClient instance. Uses ``SupabaseClient.get_instance()`` + when None. + store: ReportStore instance. Creates a default instance when None. + config: Portfolio config dict. Uses ``get_portfolio_config()`` when None. + """ + # TODO: implement — resolve defaults, store as self._client, self._store, self._cfg + raise NotImplementedError + + # ------------------------------------------------------------------ + # Portfolio lifecycle + # ------------------------------------------------------------------ + + def create_portfolio( + self, + name: str, + initial_cash: float, + currency: str = "USD", + ) -> Portfolio: + """Create a new portfolio with the given starting capital. + + Generates a UUID for ``portfolio_id``. Sets ``cash = initial_cash``. + + Args: + name: Human-readable portfolio name. + initial_cash: Starting capital in USD (or configured currency). + currency: ISO 4217 currency code. + + Returns: + Persisted Portfolio instance. + + Raises: + DuplicatePortfolioError: If the name is already in use. + ValueError: If ``initial_cash <= 0``. + """ + # TODO: implement + raise NotImplementedError + + def get_portfolio(self, portfolio_id: str) -> Portfolio: + """Fetch a portfolio by ID. + + Args: + portfolio_id: UUID of the target portfolio. + + Raises: + PortfolioNotFoundError: If not found. + """ + # TODO: implement + raise NotImplementedError + + def get_portfolio_with_holdings( + self, + portfolio_id: str, + prices: dict[str, float] | None = None, + ) -> tuple[Portfolio, list[Holding]]: + """Fetch portfolio + all holdings, optionally enriched with current prices. + + Args: + portfolio_id: UUID of the target portfolio. + prices: Optional ``{ticker: current_price}`` dict. When provided, + holdings are enriched and ``Portfolio.enrich()`` is called. + + Returns: + ``(Portfolio, list[Holding])`` — enriched when prices are supplied. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Holdings management + # ------------------------------------------------------------------ + + def add_holding( + self, + portfolio_id: str, + ticker: str, + shares: float, + price: float, + sector: str | None = None, + industry: str | None = None, + ) -> Holding: + """Buy shares and update portfolio cash and holdings. + + Business logic: + - Raises ``InsufficientCashError`` if ``portfolio.cash < shares * price`` + - If holding already exists: updates ``avg_cost`` using weighted average + - ``portfolio.cash -= shares * price`` + - Records a BUY trade automatically + + Avg cost formula:: + + new_avg_cost = (old_shares * old_avg_cost + new_shares * price) + / (old_shares + new_shares) + + Args: + portfolio_id: UUID of the target portfolio. + ticker: Ticker symbol. + shares: Number of shares to buy (must be > 0). + price: Execution price per share. + sector: Optional GICS sector name. + industry: Optional GICS industry name. + + Returns: + Updated or created Holding. + + Raises: + InsufficientCashError: If cash would go negative. + PortfolioNotFoundError: If portfolio_id does not exist. + ValueError: If shares <= 0 or price <= 0. + """ + # TODO: implement + raise NotImplementedError + + def remove_holding( + self, + portfolio_id: str, + ticker: str, + shares: float, + price: float, + ) -> Holding | None: + """Sell shares and update portfolio cash and holdings. + + Business logic: + - Raises ``HoldingNotFoundError`` if no holding exists for ticker + - Raises ``InsufficientSharesError`` if ``holding.shares < shares`` + - If ``shares == holding.shares``: deletes the holding row, returns None + - Otherwise: decrements ``holding.shares`` (avg_cost unchanged on sell) + - ``portfolio.cash += shares * price`` + - Records a SELL trade automatically + + Args: + portfolio_id: UUID of the target portfolio. + ticker: Ticker symbol. + shares: Number of shares to sell (must be > 0). + price: Execution price per share. + + Returns: + Updated Holding, or None if the position was fully closed. + + Raises: + HoldingNotFoundError: If no holding exists for this ticker. + InsufficientSharesError: If holding.shares < shares. + PortfolioNotFoundError: If portfolio_id does not exist. + ValueError: If shares <= 0 or price <= 0. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Snapshots + # ------------------------------------------------------------------ + + def take_snapshot( + self, + portfolio_id: str, + prices: dict[str, float], + ) -> PortfolioSnapshot: + """Take an immutable snapshot of the current portfolio state. + + Fetches all holdings, enriches them with ``prices``, computes + ``total_value``, then persists via ``SupabaseClient.save_snapshot()``. + + Args: + portfolio_id: UUID of the target portfolio. + prices: ``{ticker: current_price}`` for all held tickers. + + Returns: + Persisted PortfolioSnapshot. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Report convenience methods + # ------------------------------------------------------------------ + + def save_pm_decision( + self, + portfolio_id: str, + date: str, + decision: dict[str, Any], + markdown: str | None = None, + ) -> Path: + """Save a PM agent decision and update portfolio.report_path. + + Delegates to ``ReportStore.save_pm_decision()`` then updates the + ``portfolio.report_path`` column in Supabase to point to the daily + portfolio directory. + + Args: + portfolio_id: UUID of the target portfolio. + date: ISO date string, e.g. ``"2026-03-20"``. + decision: PM decision dict. + markdown: Optional human-readable markdown version. + + Returns: + Path of the written JSON file. + """ + # TODO: implement + raise NotImplementedError + + def load_pm_decision( + self, + portfolio_id: str, + date: str, + ) -> dict[str, Any] | None: + """Load a PM decision JSON. Returns None if not found. + + Args: + portfolio_id: UUID of the target portfolio. + date: ISO date string. + """ + # TODO: implement + raise NotImplementedError + + def save_risk_metrics( + self, + portfolio_id: str, + date: str, + metrics: dict[str, Any], + ) -> Path: + """Save risk computation results. Delegates to ReportStore. + + Args: + portfolio_id: UUID of the target portfolio. + date: ISO date string. + metrics: Risk metrics dict (Sharpe, Sortino, VaR, beta, etc.). + + Returns: + Path of the written file. + """ + # TODO: implement + raise NotImplementedError + + def load_risk_metrics( + self, + portfolio_id: str, + date: str, + ) -> dict[str, Any] | None: + """Load risk metrics. Returns None if not found. + + Args: + portfolio_id: UUID of the target portfolio. + date: ISO date string. + """ + # TODO: implement + raise NotImplementedError diff --git a/tradingagents/portfolio/supabase_client.py b/tradingagents/portfolio/supabase_client.py new file mode 100644 index 00000000..f2d7b0e6 --- /dev/null +++ b/tradingagents/portfolio/supabase_client.py @@ -0,0 +1,252 @@ +"""Supabase database client for the Portfolio Manager. + +Thin wrapper around ``supabase-py`` that: +- Provides a singleton connection (one client per process) +- Translates Supabase / HTTP errors into domain exceptions +- Converts raw DB rows into typed model instances + +Usage:: + + from tradingagents.portfolio.supabase_client import SupabaseClient + from tradingagents.portfolio.models import Portfolio + + client = SupabaseClient.get_instance() + portfolio = client.get_portfolio("some-uuid") + +Configuration (read from environment via ``get_portfolio_config()``): + SUPABASE_URL — Supabase project URL + SUPABASE_KEY — Supabase anon or service-role key +""" + +from __future__ import annotations + +from tradingagents.portfolio.exceptions import ( + DuplicatePortfolioError, + HoldingNotFoundError, + PortfolioNotFoundError, +) +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) + + +class SupabaseClient: + """Singleton Supabase CRUD client for portfolio data. + + All public methods translate Supabase / HTTP errors into domain exceptions + and return typed model instances. + + Do not instantiate directly — use ``SupabaseClient.get_instance()``. + """ + + _instance: "SupabaseClient | None" = None + + def __init__(self, url: str, key: str) -> None: + """Initialise the Supabase client. + + Args: + url: Supabase project URL. + key: Supabase anon or service-role key. + """ + # TODO: implement — create supabase.create_client(url, key) + raise NotImplementedError + + @classmethod + def get_instance(cls) -> "SupabaseClient": + """Return the singleton instance, creating it if necessary. + + Reads SUPABASE_URL and SUPABASE_KEY from ``get_portfolio_config()``. + + Raises: + PortfolioError: If SUPABASE_URL or SUPABASE_KEY are not configured. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Portfolio CRUD + # ------------------------------------------------------------------ + + def create_portfolio(self, portfolio: Portfolio) -> Portfolio: + """Insert a new portfolio row. + + Args: + portfolio: Portfolio instance with all required fields set. + + Returns: + Portfolio with DB-assigned timestamps. + + Raises: + DuplicatePortfolioError: If portfolio_id already exists. + """ + # TODO: implement + raise NotImplementedError + + def get_portfolio(self, portfolio_id: str) -> Portfolio: + """Fetch a portfolio by ID. + + Args: + portfolio_id: UUID of the target portfolio. + + Returns: + Portfolio instance. + + Raises: + PortfolioNotFoundError: If no portfolio has that ID. + """ + # TODO: implement + raise NotImplementedError + + def list_portfolios(self) -> list[Portfolio]: + """Return all portfolios ordered by created_at DESC.""" + # TODO: implement + raise NotImplementedError + + def update_portfolio(self, portfolio: Portfolio) -> Portfolio: + """Update mutable portfolio fields (cash, report_path, metadata). + + Args: + portfolio: Portfolio with updated field values. + + Returns: + Updated Portfolio with refreshed updated_at. + + Raises: + PortfolioNotFoundError: If portfolio_id does not exist. + """ + # TODO: implement + raise NotImplementedError + + def delete_portfolio(self, portfolio_id: str) -> None: + """Delete a portfolio and all associated data (CASCADE). + + Args: + portfolio_id: UUID of the portfolio to delete. + + Raises: + PortfolioNotFoundError: If portfolio_id does not exist. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Holdings CRUD + # ------------------------------------------------------------------ + + def upsert_holding(self, holding: Holding) -> Holding: + """Insert or update a holding row (upsert on portfolio_id + ticker). + + Args: + holding: Holding instance with all required fields set. + + Returns: + Holding with DB-assigned / refreshed timestamps. + """ + # TODO: implement + raise NotImplementedError + + def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None: + """Return the holding for (portfolio_id, ticker), or None if not found. + + Args: + portfolio_id: UUID of the target portfolio. + ticker: Ticker symbol (case-insensitive, stored as uppercase). + """ + # TODO: implement + raise NotImplementedError + + def list_holdings(self, portfolio_id: str) -> list[Holding]: + """Return all holdings for a portfolio ordered by cost_basis DESC. + + Args: + portfolio_id: UUID of the target portfolio. + """ + # TODO: implement + raise NotImplementedError + + def delete_holding(self, portfolio_id: str, ticker: str) -> None: + """Delete the holding for (portfolio_id, ticker). + + Args: + portfolio_id: UUID of the target portfolio. + ticker: Ticker symbol. + + Raises: + HoldingNotFoundError: If no such holding exists. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Trades + # ------------------------------------------------------------------ + + def record_trade(self, trade: Trade) -> Trade: + """Insert a new trade record. Immutable — no update method. + + Args: + trade: Trade instance with all required fields set. + + Returns: + Trade with DB-assigned trade_id and trade_date. + """ + # TODO: implement + raise NotImplementedError + + def list_trades( + self, + portfolio_id: str, + ticker: str | None = None, + limit: int = 100, + ) -> list[Trade]: + """Return recent trades for a portfolio, newest first. + + Args: + portfolio_id: Filter by portfolio. + ticker: Optional additional filter by ticker symbol. + limit: Maximum number of rows to return. + """ + # TODO: implement + raise NotImplementedError + + # ------------------------------------------------------------------ + # Snapshots + # ------------------------------------------------------------------ + + def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot: + """Insert a new immutable portfolio snapshot. + + Args: + snapshot: PortfolioSnapshot with all required fields set. + + Returns: + Snapshot with DB-assigned snapshot_id and snapshot_date. + """ + # TODO: implement + raise NotImplementedError + + def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None: + """Return the most recent snapshot for a portfolio, or None. + + Args: + portfolio_id: UUID of the target portfolio. + """ + # TODO: implement + raise NotImplementedError + + def list_snapshots( + self, + portfolio_id: str, + limit: int = 30, + ) -> list[PortfolioSnapshot]: + """Return snapshots newest-first up to limit. + + Args: + portfolio_id: UUID of the target portfolio. + limit: Maximum number of snapshots to return. + """ + # TODO: implement + raise NotImplementedError From 7ea9866d1d31e96ebd981c3e2d938064790d3261 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:48:40 +0000 Subject: [PATCH 3/5] =?UTF-8?q?docs:=20ADR-012=20=E2=80=94=20raw=20supabas?= =?UTF-8?q?e-py=20over=20Prisma/SQLAlchemy=20for=20portfolio=20data=20laye?= =?UTF-8?q?r?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> --- docs/agent/decisions/012-portfolio-no-orm.md | 92 ++++++++++++++++++++ docs/portfolio/00_overview.md | 26 ++++++ tradingagents/portfolio/supabase_client.py | 10 ++- 3 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 docs/agent/decisions/012-portfolio-no-orm.md diff --git a/docs/agent/decisions/012-portfolio-no-orm.md b/docs/agent/decisions/012-portfolio-no-orm.md new file mode 100644 index 00000000..58835368 --- /dev/null +++ b/docs/agent/decisions/012-portfolio-no-orm.md @@ -0,0 +1,92 @@ +--- +type: decision +status: active +date: 2026-03-20 +agent_author: "claude" +tags: [portfolio, database, supabase, orm, prisma] +related_files: + - tradingagents/portfolio/supabase_client.py + - tradingagents/portfolio/repository.py + - tradingagents/portfolio/migrations/001_initial_schema.sql +--- + +## Context + +When designing the Portfolio Manager data layer (Phase 1), the question arose: +should we use an ORM (specifically **Prisma**) or keep the raw `supabase-py` +client that the scaffolding already plans to use? + +The options considered were: + +| Option | Description | +|--------|-------------| +| **Raw `supabase-py`** (chosen) | Direct Supabase PostgREST client, builder-pattern API | +| **Prisma Python** (`prisma-client-py`) | Code-generated type-safe ORM backed by Node.js | +| **SQLAlchemy** | Full ORM with Core + ORM layers, Alembic migrations | + +## The Decision + +**Use raw `supabase-py` without an ORM for the portfolio data layer.** + +The data access layer (`supabase_client.py`) wraps the Supabase client directly. +Our own `Portfolio`, `Holding`, `Trade`, and `PortfolioSnapshot` dataclasses +provide the type-safety layer; serialisation is handled by `to_dict()` / +`from_dict()` on each model. + +## Why Not Prisma + +1. **Node.js runtime dependency** — `prisma-client-py` uses Prisma's Node.js + engine at code-generation time. This adds a non-Python runtime requirement + to a Python-only project. + +2. **Conflicts with Supabase's migration tooling** — the project already uses + Supabase's SQL migration files (`migrations/001_initial_schema.sql`) and the + Supabase dashboard for schema changes. Prisma's `prisma migrate` maintains + its own shadow database and migration state, creating two competing systems. + +3. **Code generation build step** — every schema change requires running + `prisma generate` before the Python code works. This complicates CI, local + setup, and agent-driven development. + +4. **Overkill for 4 tables** — the portfolio schema has exactly 4 tables with + straightforward CRUD. Prisma's relationship traversal and complex query + features offer no benefit here. + +## Why Not SQLAlchemy + +1. **Not using a local database** — the database is managed by Supabase (hosted + PostgreSQL). SQLAlchemy's connection-pooling and engine management are + designed for direct database connections, which bypass Supabase's PostgREST + API and Row Level Security. + +2. **Extra dependency** — SQLAlchemy + Alembic would be significant new + dependencies for a non-DB-heavy app. + +## Why Raw `supabase-py` Is Sufficient + +- `supabase-py` provides a clean builder-pattern API: + `client.table("holdings").select("*").eq("portfolio_id", id).execute()` +- Our dataclasses already provide compile-time type safety and lossless + serialisation; the client only handles transport. +- Migrations are plain SQL files — readable, versionable, Supabase-native. +- `SupabaseClient` is a thin singleton wrapper that translates HTTP errors into + domain exceptions — this gives us the ORM-like error-handling benefit without + the complexity. + +## Constraints + +- **Do not** add an ORM dependency (`prisma-client-py`, `sqlalchemy`, `tortoise-orm`) + to `pyproject.toml` without revisiting this decision. +- **Do not** bypass `SupabaseClient` by importing `supabase` directly in other + modules — always go through `PortfolioRepository`. +- If the schema grows beyond ~10 tables or requires complex multi-table joins, + revisit this decision and consider SQLAlchemy Core (not the ORM layer) with + direct `asyncpg` connections. + +## Actionable Rules + +- All DB access goes through `PortfolioRepository` → `SupabaseClient`. +- Migrations are `.sql` files in `tradingagents/portfolio/migrations/`, run via + the Supabase SQL Editor or `supabase db push`. +- Type safety comes from dataclass `to_dict()` / `from_dict()` — not from a + code-generated ORM schema. diff --git a/docs/portfolio/00_overview.md b/docs/portfolio/00_overview.md index a6f7ab46..741a2605 100644 --- a/docs/portfolio/00_overview.md +++ b/docs/portfolio/00_overview.md @@ -72,6 +72,32 @@ investment portfolio end-to-end. It performs the following actions in sequence: The `report_path` column in the `portfolios` table points to the daily portfolio subdirectory on disk: `reports/daily/{date}/portfolio/`. +### Data Access Layer: raw `supabase-py` (no ORM) + +The Python code talks to Supabase through the raw `supabase-py` client — **no +ORM** (Prisma, SQLAlchemy, etc.) is used. + +**Why not Prisma?** +- `prisma-client-py` requires a Node.js runtime for code generation — an + extra non-Python dependency in a Python-only project. +- Prisma's `prisma migrate` conflicts with Supabase's own SQL migration tooling + (we use `.sql` files in `tradingagents/portfolio/migrations/`). +- 4 tables with straightforward CRUD don't benefit from a code-generated ORM. + +**Why not SQLAlchemy?** +- Supabase is accessed via PostgREST (HTTP API), not a direct TCP database + connection. SQLAlchemy is designed for direct connections and would bypass + Supabase's Row Level Security. +- Extra dependency overhead for a non-DB-heavy feature. + +**`supabase-py` is sufficient because:** +- Its builder-pattern API (`client.table("holdings").select("*").eq(...)`) + covers all needed queries cleanly. +- Our own dataclasses handle type-safety via `to_dict()` / `from_dict()`. +- Plain SQL migration files are readable, versionable, and Supabase-native. + +> Full rationale: `docs/agent/decisions/012-portfolio-no-orm.md` + --- ## 5-Phase Workflow diff --git a/tradingagents/portfolio/supabase_client.py b/tradingagents/portfolio/supabase_client.py index f2d7b0e6..efd87b34 100644 --- a/tradingagents/portfolio/supabase_client.py +++ b/tradingagents/portfolio/supabase_client.py @@ -1,9 +1,15 @@ """Supabase database client for the Portfolio Manager. -Thin wrapper around ``supabase-py`` that: +Thin wrapper around ``supabase-py`` (no ORM) that: - Provides a singleton connection (one client per process) - Translates Supabase / HTTP errors into domain exceptions -- Converts raw DB rows into typed model instances +- Converts raw DB rows into typed model instances via ``Model.from_dict()`` + +**No ORM is used here by design** — see +``docs/agent/decisions/012-portfolio-no-orm.md`` for the full rationale. +In short: ``supabase-py``'s builder-pattern API is sufficient for 4 tables; +Prisma and SQLAlchemy add build-step / runtime complexity that isn't warranted +for this non-DB-heavy feature. Usage:: From aa4dcdeb806684d312392b0a0d66b0424edc9c65 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 11:16:39 +0000 Subject: [PATCH 4/5] feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> --- tests/portfolio/conftest.py | 85 ++++++-- tests/portfolio/test_models.py | 183 ++++++++++++++---- tests/portfolio/test_report_store.py | 115 +++++++---- .../migrations/001_initial_schema.sql | 6 +- tradingagents/portfolio/models.py | 147 ++++++++++++-- tradingagents/portfolio/report_store.py | 76 +++++--- 6 files changed, 464 insertions(+), 148 deletions(-) diff --git a/tests/portfolio/conftest.py b/tests/portfolio/conftest.py index bbbf0d78..073d318d 100644 --- a/tests/portfolio/conftest.py +++ b/tests/portfolio/conftest.py @@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when from __future__ import annotations import os -import uuid from pathlib import Path from unittest.mock import MagicMock import pytest +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) +from tradingagents.portfolio.report_store import ReportStore +from tradingagents.portfolio.supabase_client import SupabaseClient + # --------------------------------------------------------------------------- # Skip marker for Supabase integration tests # --------------------------------------------------------------------------- @@ -51,31 +59,72 @@ def sample_holding_id() -> str: @pytest.fixture -def sample_portfolio(sample_portfolio_id: str): +def sample_portfolio(sample_portfolio_id: str) -> Portfolio: """Return an unsaved Portfolio instance for testing.""" - # TODO: implement — construct a Portfolio dataclass with test values - raise NotImplementedError + return Portfolio( + portfolio_id=sample_portfolio_id, + name="Test Portfolio", + cash=50_000.0, + initial_cash=100_000.0, + currency="USD", + created_at="2026-03-20T00:00:00Z", + updated_at="2026-03-20T00:00:00Z", + report_path="reports/daily/2026-03-20/portfolio", + metadata={"strategy": "test"}, + ) @pytest.fixture -def sample_holding(sample_portfolio_id: str, sample_holding_id: str): +def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding: """Return an unsaved Holding instance for testing.""" - # TODO: implement — construct a Holding dataclass with test values - raise NotImplementedError + return Holding( + holding_id=sample_holding_id, + portfolio_id=sample_portfolio_id, + ticker="AAPL", + shares=100.0, + avg_cost=150.0, + sector="Technology", + industry="Consumer Electronics", + created_at="2026-03-20T00:00:00Z", + updated_at="2026-03-20T00:00:00Z", + ) @pytest.fixture -def sample_trade(sample_portfolio_id: str): +def sample_trade(sample_portfolio_id: str) -> Trade: """Return an unsaved Trade instance for testing.""" - # TODO: implement — construct a Trade dataclass with test values - raise NotImplementedError + return Trade( + trade_id="33333333-3333-3333-3333-333333333333", + portfolio_id=sample_portfolio_id, + ticker="AAPL", + action="BUY", + shares=100.0, + price=150.0, + total_value=15_000.0, + trade_date="2026-03-20T10:00:00Z", + rationale="Strong momentum signal", + signal_source="scanner", + metadata={"confidence": 0.85}, + ) @pytest.fixture -def sample_snapshot(sample_portfolio_id: str): +def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot: """Return an unsaved PortfolioSnapshot instance for testing.""" - # TODO: implement — construct a PortfolioSnapshot dataclass with test values - raise NotImplementedError + return PortfolioSnapshot( + snapshot_id="44444444-4444-4444-4444-444444444444", + portfolio_id=sample_portfolio_id, + snapshot_date="2026-03-20", + total_value=115_000.0, + cash=50_000.0, + equity_value=65_000.0, + num_positions=2, + holdings_snapshot=[ + {"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0}, + {"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0}, + ], + metadata={"note": "end of day snapshot"}, + ) # --------------------------------------------------------------------------- @@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path: @pytest.fixture -def report_store(tmp_reports: Path): +def report_store(tmp_reports: Path) -> ReportStore: """ReportStore instance backed by a temporary directory.""" - # TODO: implement — return ReportStore(base_dir=tmp_reports) - raise NotImplementedError + return ReportStore(base_dir=tmp_reports) # --------------------------------------------------------------------------- @@ -104,7 +152,6 @@ def report_store(tmp_reports: Path): @pytest.fixture -def mock_supabase_client(): +def mock_supabase_client() -> MagicMock: """MagicMock of SupabaseClient for unit tests that don't hit the DB.""" - # TODO: implement — return MagicMock(spec=SupabaseClient) - raise NotImplementedError + return MagicMock(spec=SupabaseClient) diff --git a/tests/portfolio/test_models.py b/tests/portfolio/test_models.py index 72f3c798..bcd2258a 100644 --- a/tests/portfolio/test_models.py +++ b/tests/portfolio/test_models.py @@ -16,6 +16,13 @@ from __future__ import annotations import pytest +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) + # --------------------------------------------------------------------------- # Portfolio round-trip @@ -24,19 +31,41 @@ import pytest def test_portfolio_to_dict_round_trip(sample_portfolio): """Portfolio.to_dict() -> Portfolio.from_dict() must be lossless.""" - # TODO: implement - # d = sample_portfolio.to_dict() - # restored = Portfolio.from_dict(d) - # assert restored.portfolio_id == sample_portfolio.portfolio_id - # assert restored.cash == sample_portfolio.cash - # ... all stored fields - raise NotImplementedError + d = sample_portfolio.to_dict() + restored = Portfolio.from_dict(d) + assert restored.portfolio_id == sample_portfolio.portfolio_id + assert restored.name == sample_portfolio.name + assert restored.cash == sample_portfolio.cash + assert restored.initial_cash == sample_portfolio.initial_cash + assert restored.currency == sample_portfolio.currency + assert restored.created_at == sample_portfolio.created_at + assert restored.updated_at == sample_portfolio.updated_at + assert restored.report_path == sample_portfolio.report_path + assert restored.metadata == sample_portfolio.metadata def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): """to_dict() must not include computed fields (total_value, equity_value, cash_pct).""" - # TODO: implement - raise NotImplementedError + d = sample_portfolio.to_dict() + assert "total_value" not in d + assert "equity_value" not in d + assert "cash_pct" not in d + + +def test_portfolio_from_dict_defaults_optional_fields(): + """from_dict() must tolerate missing optional fields.""" + minimal = { + "portfolio_id": "pid-1", + "name": "Minimal", + "cash": 1000.0, + "initial_cash": 1000.0, + } + p = Portfolio.from_dict(minimal) + assert p.currency == "USD" + assert p.created_at == "" + assert p.updated_at == "" + assert p.report_path is None + assert p.metadata == {} # --------------------------------------------------------------------------- @@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): def test_holding_to_dict_round_trip(sample_holding): """Holding.to_dict() -> Holding.from_dict() must be lossless.""" - # TODO: implement - raise NotImplementedError + d = sample_holding.to_dict() + restored = Holding.from_dict(d) + assert restored.holding_id == sample_holding.holding_id + assert restored.portfolio_id == sample_holding.portfolio_id + assert restored.ticker == sample_holding.ticker + assert restored.shares == sample_holding.shares + assert restored.avg_cost == sample_holding.avg_cost + assert restored.sector == sample_holding.sector + assert restored.industry == sample_holding.industry def test_holding_to_dict_excludes_runtime_fields(sample_holding): """to_dict() must not include current_price, current_value, weight, etc.""" - # TODO: implement - raise NotImplementedError + d = sample_holding.to_dict() + for field in ("current_price", "current_value", "cost_basis", + "unrealized_pnl", "unrealized_pnl_pct", "weight"): + assert field not in d # --------------------------------------------------------------------------- @@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding): def test_trade_to_dict_round_trip(sample_trade): """Trade.to_dict() -> Trade.from_dict() must be lossless.""" - # TODO: implement - raise NotImplementedError + d = sample_trade.to_dict() + restored = Trade.from_dict(d) + assert restored.trade_id == sample_trade.trade_id + assert restored.portfolio_id == sample_trade.portfolio_id + assert restored.ticker == sample_trade.ticker + assert restored.action == sample_trade.action + assert restored.shares == sample_trade.shares + assert restored.price == sample_trade.price + assert restored.total_value == sample_trade.total_value + assert restored.trade_date == sample_trade.trade_date + assert restored.rationale == sample_trade.rationale + assert restored.signal_source == sample_trade.signal_source + assert restored.metadata == sample_trade.metadata # --------------------------------------------------------------------------- @@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade): def test_snapshot_to_dict_round_trip(sample_snapshot): """PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip.""" - # TODO: implement - raise NotImplementedError + d = sample_snapshot.to_dict() + restored = PortfolioSnapshot.from_dict(d) + assert restored.snapshot_id == sample_snapshot.snapshot_id + assert restored.portfolio_id == sample_snapshot.portfolio_id + assert restored.snapshot_date == sample_snapshot.snapshot_date + assert restored.total_value == sample_snapshot.total_value + assert restored.cash == sample_snapshot.cash + assert restored.equity_value == sample_snapshot.equity_value + assert restored.num_positions == sample_snapshot.num_positions + assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot + assert restored.metadata == sample_snapshot.metadata + + +def test_snapshot_from_dict_parses_holdings_snapshot_json_string(): + """from_dict() must parse holdings_snapshot when it arrives as a JSON string.""" + import json + holdings = [{"ticker": "AAPL", "shares": 10.0}] + data = { + "snapshot_id": "snap-1", + "portfolio_id": "pid-1", + "snapshot_date": "2026-03-20", + "total_value": 110_000.0, + "cash": 10_000.0, + "equity_value": 100_000.0, + "num_positions": 1, + "holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase + } + snap = PortfolioSnapshot.from_dict(data) + assert snap.holdings_snapshot == holdings # --------------------------------------------------------------------------- @@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot): def test_holding_enrich_computes_current_value(sample_holding): """enrich() must set current_value = current_price * shares.""" - # TODO: implement - # sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) - # assert sample_holding.current_value == 200.0 * sample_holding.shares - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert sample_holding.current_value == 200.0 * sample_holding.shares def test_holding_enrich_computes_unrealized_pnl(sample_holding): """enrich() must set unrealized_pnl = current_value - cost_basis.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + expected_cost_basis = sample_holding.avg_cost * sample_holding.shares + expected_pnl = sample_holding.current_value - expected_cost_basis + assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl) + + +def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding): + """enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis.""" + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + cost_basis = sample_holding.avg_cost * sample_holding.shares + expected_pct = sample_holding.unrealized_pnl / cost_basis + assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct) def test_holding_enrich_computes_weight(sample_holding): """enrich() must set weight = current_value / portfolio_total_value.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + expected_weight = sample_holding.current_value / 100_000.0 + assert sample_holding.weight == pytest.approx(expected_weight) + + +def test_holding_enrich_returns_self(sample_holding): + """enrich() must return self for chaining.""" + result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert result is sample_holding def test_holding_enrich_handles_zero_cost(sample_holding): """When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError).""" - # TODO: implement - raise NotImplementedError + sample_holding.avg_cost = 0.0 + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert sample_holding.unrealized_pnl_pct == 0.0 def test_holding_enrich_handles_zero_portfolio_value(sample_holding): """When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError).""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0) + assert sample_holding.weight == 0.0 # --------------------------------------------------------------------------- @@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding): def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding): """Portfolio.enrich() must compute total_value = cash + sum(holding.current_value).""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + expected_equity = 200.0 * sample_holding.shares + assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity) + + +def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding): + """Portfolio.enrich() must set equity_value = sum(holding.current_value).""" + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares) def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding): """Portfolio.enrich() must compute cash_pct = cash / total_value.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + expected_pct = sample_portfolio.cash / sample_portfolio.total_value + assert sample_portfolio.cash_pct == pytest.approx(expected_pct) + + +def test_portfolio_enrich_returns_self(sample_portfolio): + """enrich() must return self for chaining.""" + result = sample_portfolio.enrich([]) + assert result is sample_portfolio + + +def test_portfolio_enrich_no_holdings(sample_portfolio): + """Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash.""" + sample_portfolio.enrich([]) + assert sample_portfolio.equity_value == 0.0 + assert sample_portfolio.total_value == sample_portfolio.cash + assert sample_portfolio.cash_pct == 1.0 diff --git a/tests/portfolio/test_report_store.py b/tests/portfolio/test_report_store.py index ba111799..9f042a26 100644 --- a/tests/portfolio/test_report_store.py +++ b/tests/portfolio/test_report_store.py @@ -12,10 +12,14 @@ Run:: from __future__ import annotations +import json from pathlib import Path import pytest +from tradingagents.portfolio.exceptions import ReportStoreError +from tradingagents.portfolio.report_store import ReportStore + # --------------------------------------------------------------------------- # Macro scan @@ -24,19 +28,17 @@ import pytest def test_save_and_load_scan(report_store, tmp_reports): """save_scan() then load_scan() must return the original data.""" - # TODO: implement - # data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} - # path = report_store.save_scan("2026-03-20", data) - # assert path.exists() - # loaded = report_store.load_scan("2026-03-20") - # assert loaded == data - raise NotImplementedError + data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} + path = report_store.save_scan("2026-03-20", data) + assert path.exists() + loaded = report_store.load_scan("2026-03-20") + assert loaded == data def test_load_scan_returns_none_for_missing_file(report_store): """load_scan() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_scan("1900-01-01") + assert result is None # --------------------------------------------------------------------------- @@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store): def test_save_and_load_analysis(report_store): """save_analysis() then load_analysis() must return the original data.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92} + report_store.save_analysis("2026-03-20", "AAPL", data) + loaded = report_store.load_analysis("2026-03-20", "AAPL") + assert loaded == data def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): """Ticker symbol must be stored as uppercase in the directory path.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "aapl"} + report_store.save_analysis("2026-03-20", "aapl", data) + expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json" + assert expected.exists() + # load with lowercase should still work + loaded = report_store.load_analysis("2026-03-20", "aapl") + assert loaded == data # --------------------------------------------------------------------------- @@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): def test_save_and_load_holding_review(report_store): """save_holding_review() then load_holding_review() must round-trip.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0} + report_store.save_holding_review("2026-03-20", "MSFT", data) + loaded = report_store.load_holding_review("2026-03-20", "MSFT") + assert loaded == data def test_load_holding_review_returns_none_for_missing(report_store): """load_holding_review() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_holding_review("1900-01-01", "ZZZZ") + assert result is None # --------------------------------------------------------------------------- @@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store): def test_save_and_load_risk_metrics(report_store): """save_risk_metrics() then load_risk_metrics() must round-trip.""" - # TODO: implement - raise NotImplementedError + data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12} + report_store.save_risk_metrics("2026-03-20", "pid-123", data) + loaded = report_store.load_risk_metrics("2026-03-20", "pid-123") + assert loaded == data # --------------------------------------------------------------------------- @@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store): def test_save_and_load_pm_decision_json(report_store): """save_pm_decision() then load_pm_decision() must round-trip JSON.""" - # TODO: implement - # decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} - # report_store.save_pm_decision("2026-03-20", "pid-123", decision) - # loaded = report_store.load_pm_decision("2026-03-20", "pid-123") - # assert loaded == decision - raise NotImplementedError + decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} + report_store.save_pm_decision("2026-03-20", "pid-123", decision) + loaded = report_store.load_pm_decision("2026-03-20", "pid-123") + assert loaded == decision def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports): """When markdown is passed to save_pm_decision(), .md file must be written.""" - # TODO: implement - raise NotImplementedError + decision = {"sells": [], "buys": []} + md_text = "# Decision\n\nHold everything." + report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text) + md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md" + assert md_path.exists() + assert md_path.read_text(encoding="utf-8") == md_text def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports): """When markdown=None, no .md file should be written.""" - # TODO: implement - raise NotImplementedError + decision = {"sells": [], "buys": []} + report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None) + md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md" + assert not md_path.exists() def test_load_pm_decision_returns_none_for_missing(report_store): """load_pm_decision() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_pm_decision("1900-01-01", "pid-none") + assert result is None def test_list_pm_decisions(report_store): """list_pm_decisions() must return all saved decision paths, newest first.""" - # TODO: implement - # Save decisions for multiple dates, verify order - raise NotImplementedError + dates = ["2026-03-18", "2026-03-19", "2026-03-20"] + for d in dates: + report_store.save_pm_decision(d, "pid-abc", {"date": d}) + paths = report_store.list_pm_decisions("pid-abc") + assert len(paths) == 3 + # Sorted newest first by ISO date string ordering + date_parts = [p.parent.parent.name for p in paths] + assert date_parts == sorted(dates, reverse=True) # --------------------------------------------------------------------------- @@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store): def test_directories_created_on_write(report_store, tmp_reports): """Directories must be created automatically on first write.""" - # TODO: implement - # assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists() - # report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) - # assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir() - raise NotImplementedError + target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio" + assert not target_dir.exists() + report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) + assert target_dir.is_dir() def test_json_formatted_with_indent(report_store, tmp_reports): """Written JSON files must use indent=2 for human readability.""" - # TODO: implement - # Write a file, read the raw bytes, verify indentation - raise NotImplementedError + data = {"key": "value", "nested": {"a": 1}} + path = report_store.save_scan("2026-03-20", data) + raw = path.read_text(encoding="utf-8") + # indent=2 means lines like ' "key": ...' + assert ' "key"' in raw + + +def test_read_json_raises_on_corrupt_file(report_store, tmp_reports): + """_read_json must raise ReportStoreError for corrupt JSON.""" + corrupt = tmp_reports / "corrupt.json" + corrupt.write_text("not valid json{{{", encoding="utf-8") + with pytest.raises(ReportStoreError): + report_store._read_json(corrupt) diff --git a/tradingagents/portfolio/migrations/001_initial_schema.sql b/tradingagents/portfolio/migrations/001_initial_schema.sql index 42724b7a..6260d9d7 100644 --- a/tradingagents/portfolio/migrations/001_initial_schema.sql +++ b/tradingagents/portfolio/migrations/001_initial_schema.sql @@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades ( trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, ticker TEXT NOT NULL, - action TEXT NOT NULL, + action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')), shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), price NUMERIC(18,4) NOT NULL CHECK (price > 0), total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0), trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), rationale TEXT, -- PM agent rationale for this trade signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent' - metadata JSONB NOT NULL DEFAULT '{}', - - CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL')) + metadata JSONB NOT NULL DEFAULT '{}' ); COMMENT ON TABLE trades IS diff --git a/tradingagents/portfolio/models.py b/tradingagents/portfolio/models.py index 9a61fe9b..0cbfd442 100644 --- a/tradingagents/portfolio/models.py +++ b/tradingagents/portfolio/models.py @@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with: - ``from_dict()`` class method for deserialisation - ``enrich()`` for attaching runtime-computed fields +**float vs Decimal** — monetary fields (cash, price, shares, etc.) use plain +``float`` throughout. Rationale: + +1. This is **mock trading only** — no real money changes hands. The cost of a + subtle floating-point rounding error is zero. +2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float`` + already. Converting to ``Decimal`` at the boundary would require a custom + JSON encoder *and* decoder everywhere, for no practical gain. +3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises + ``TypeError`` without a custom encoder. +4. If this ever becomes real-money trading, replace ``float`` with + ``decimal.Decimal`` and add a ``DecimalEncoder`` — the interface is + identical and the change is localised to this file. + See ``docs/portfolio/02_data_models.md`` for full field specifications. """ from __future__ import annotations +import json from dataclasses import dataclass, field from typing import Any @@ -48,8 +63,17 @@ class Portfolio: Runtime-computed fields are excluded. """ - # TODO: implement - raise NotImplementedError + return { + "portfolio_id": self.portfolio_id, + "name": self.name, + "cash": self.cash, + "initial_cash": self.initial_cash, + "currency": self.currency, + "created_at": self.created_at, + "updated_at": self.updated_at, + "report_path": self.report_path, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Portfolio": @@ -57,8 +81,17 @@ class Portfolio: Missing optional fields default gracefully. Extra keys are ignored. """ - # TODO: implement - raise NotImplementedError + return cls( + portfolio_id=data["portfolio_id"], + name=data["name"], + cash=float(data["cash"]), + initial_cash=float(data["initial_cash"]), + currency=data.get("currency", "USD"), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + report_path=data.get("report_path"), + metadata=data.get("metadata") or {}, + ) def enrich(self, holdings: list["Holding"]) -> "Portfolio": """Compute total_value, equity_value, cash_pct from holdings. @@ -69,8 +102,12 @@ class Portfolio: holdings: List of Holding objects with current_value populated (i.e., ``holding.enrich()`` already called). """ - # TODO: implement - raise NotImplementedError + self.equity_value = sum( + h.current_value for h in holdings if h.current_value is not None + ) + self.total_value = self.cash + self.equity_value + self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0 + return self # --------------------------------------------------------------------------- @@ -106,14 +143,32 @@ class Holding: def to_dict(self) -> dict[str, Any]: """Serialise stored fields only (runtime-computed fields excluded).""" - # TODO: implement - raise NotImplementedError + return { + "holding_id": self.holding_id, + "portfolio_id": self.portfolio_id, + "ticker": self.ticker, + "shares": self.shares, + "avg_cost": self.avg_cost, + "sector": self.sector, + "industry": self.industry, + "created_at": self.created_at, + "updated_at": self.updated_at, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Holding": """Deserialise from a DB row or JSON dict.""" - # TODO: implement - raise NotImplementedError + return cls( + holding_id=data["holding_id"], + portfolio_id=data["portfolio_id"], + ticker=data["ticker"], + shares=float(data["shares"]), + avg_cost=float(data["avg_cost"]), + sector=data.get("sector"), + industry=data.get("industry"), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + ) def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding": """Populate runtime-computed fields in-place and return self. @@ -129,8 +184,17 @@ class Holding: current_price: Latest market price for this ticker. portfolio_total_value: Total portfolio value (cash + equity). """ - # TODO: implement - raise NotImplementedError + self.current_price = current_price + self.current_value = current_price * self.shares + self.cost_basis = self.avg_cost * self.shares + self.unrealized_pnl = self.current_value - self.cost_basis + self.unrealized_pnl_pct = ( + self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0 + ) + self.weight = ( + self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0 + ) + return self # --------------------------------------------------------------------------- @@ -159,14 +223,36 @@ class Trade: def to_dict(self) -> dict[str, Any]: """Serialise all fields.""" - # TODO: implement - raise NotImplementedError + return { + "trade_id": self.trade_id, + "portfolio_id": self.portfolio_id, + "ticker": self.ticker, + "action": self.action, + "shares": self.shares, + "price": self.price, + "total_value": self.total_value, + "trade_date": self.trade_date, + "rationale": self.rationale, + "signal_source": self.signal_source, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Trade": """Deserialise from a DB row or JSON dict.""" - # TODO: implement - raise NotImplementedError + return cls( + trade_id=data["trade_id"], + portfolio_id=data["portfolio_id"], + ticker=data["ticker"], + action=data["action"], + shares=float(data["shares"]), + price=float(data["price"]), + total_value=float(data["total_value"]), + trade_date=data.get("trade_date", ""), + rationale=data.get("rationale"), + signal_source=data.get("signal_source"), + metadata=data.get("metadata") or {}, + ) # --------------------------------------------------------------------------- @@ -194,8 +280,17 @@ class PortfolioSnapshot: def to_dict(self) -> dict[str, Any]: """Serialise all fields. ``holdings_snapshot`` is already a list[dict].""" - # TODO: implement - raise NotImplementedError + return { + "snapshot_id": self.snapshot_id, + "portfolio_id": self.portfolio_id, + "snapshot_date": self.snapshot_date, + "total_value": self.total_value, + "cash": self.cash, + "equity_value": self.equity_value, + "num_positions": self.num_positions, + "holdings_snapshot": self.holdings_snapshot, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot": @@ -203,5 +298,17 @@ class PortfolioSnapshot: ``holdings_snapshot`` is parsed from a JSON string when needed. """ - # TODO: implement - raise NotImplementedError + holdings_snapshot = data.get("holdings_snapshot", []) + if isinstance(holdings_snapshot, str): + holdings_snapshot = json.loads(holdings_snapshot) + return cls( + snapshot_id=data["snapshot_id"], + portfolio_id=data["portfolio_id"], + snapshot_date=data["snapshot_date"], + total_value=float(data["total_value"]), + cash=float(data["cash"]), + equity_value=float(data["equity_value"]), + num_positions=int(data["num_positions"]), + holdings_snapshot=holdings_snapshot, + metadata=data.get("metadata") or {}, + ) diff --git a/tradingagents/portfolio/report_store.py b/tradingagents/portfolio/report_store.py index 5ed8594a..2d641693 100644 --- a/tradingagents/portfolio/report_store.py +++ b/tradingagents/portfolio/report_store.py @@ -28,9 +28,12 @@ Usage:: from __future__ import annotations +import json from pathlib import Path from typing import Any +from tradingagents.portfolio.exceptions import ReportStoreError + class ReportStore: """Filesystem document store for all portfolio-related reports. @@ -48,8 +51,7 @@ class ReportStore: Override via the ``PORTFOLIO_DATA_DIR`` env var or ``get_portfolio_config()["data_dir"]``. """ - # TODO: implement — store Path(base_dir), resolve as needed - raise NotImplementedError + self._base_dir = Path(base_dir) # ------------------------------------------------------------------ # Internal helpers @@ -60,8 +62,7 @@ class ReportStore: Path: ``{base_dir}/daily/{date}/portfolio/`` """ - # TODO: implement - raise NotImplementedError + return self._base_dir / "daily" / date / "portfolio" def _write_json(self, path: Path, data: dict[str, Any]) -> Path: """Write a dict to a JSON file, creating parent directories as needed. @@ -76,8 +77,12 @@ class ReportStore: Raises: ReportStoreError: On filesystem write failure. """ - # TODO: implement - raise NotImplementedError + try: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + return path + except OSError as exc: + raise ReportStoreError(f"Failed to write {path}: {exc}") from exc def _read_json(self, path: Path) -> dict[str, Any] | None: """Read a JSON file, returning None if the file does not exist. @@ -85,8 +90,12 @@ class ReportStore: Raises: ReportStoreError: On JSON parse error (file exists but is corrupt). """ - # TODO: implement - raise NotImplementedError + if not path.exists(): + return None + try: + return json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc # ------------------------------------------------------------------ # Macro Scan @@ -104,13 +113,13 @@ class ReportStore: Returns: Path of the written file. """ - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json" + return self._write_json(path, data) def load_scan(self, date: str) -> dict[str, Any] | None: """Load macro scan summary. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json" + return self._read_json(path) # ------------------------------------------------------------------ # Per-Ticker Analysis @@ -126,13 +135,13 @@ class ReportStore: ticker: Ticker symbol (stored as uppercase). data: Analysis output dict. """ - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json" + return self._write_json(path, data) def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None: """Load per-ticker analysis JSON. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json" + return self._read_json(path) # ------------------------------------------------------------------ # Holding Reviews @@ -153,13 +162,13 @@ class ReportStore: ticker: Ticker symbol (stored as uppercase). data: HoldingReviewerAgent output dict. """ - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json" + return self._write_json(path, data) def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None: """Load holding review output. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json" + return self._read_json(path) # ------------------------------------------------------------------ # Risk Metrics @@ -180,8 +189,8 @@ class ReportStore: portfolio_id: UUID of the target portfolio. data: Risk metrics dict (Sharpe, Sortino, VaR, etc.). """ - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json" + return self._write_json(path, data) def load_risk_metrics( self, @@ -189,8 +198,8 @@ class ReportStore: portfolio_id: str, ) -> dict[str, Any] | None: """Load risk metrics. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json" + return self._read_json(path) # ------------------------------------------------------------------ # PM Decisions @@ -218,8 +227,15 @@ class ReportStore: Returns: Path of the written JSON file. """ - # TODO: implement - raise NotImplementedError + json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json" + self._write_json(json_path, data) + if markdown is not None: + md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md" + try: + md_path.write_text(markdown, encoding="utf-8") + except OSError as exc: + raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc + return json_path def load_pm_decision( self, @@ -227,8 +243,8 @@ class ReportStore: portfolio_id: str, ) -> dict[str, Any] | None: """Load PM decision JSON. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json" + return self._read_json(path) def list_pm_decisions(self, portfolio_id: str) -> list[Path]: """Return all saved PM decision JSON paths for portfolio_id, newest first. @@ -241,5 +257,5 @@ class ReportStore: Returns: Sorted list of Path objects, newest date first. """ - # TODO: implement - raise NotImplementedError + pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json" + return sorted(self._base_dir.glob(pattern), reverse=True) From a17e5f3707a59f979c2e8cc749137be851f5209b Mon Sep 17 00:00:00 2001 From: Ahmet Guzererler Date: Fri, 20 Mar 2026 14:06:50 +0100 Subject: [PATCH 5/5] =?UTF-8?q?feat:=20complete=20portfolio=20data=20found?= =?UTF-8?q?ation=20=E2=80=94=20psycopg2=20client,=20repository,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace supabase-py stubs with working psycopg2 implementation using Supabase pooler connection string. Implement full business logic in repository (avg cost basis, cash accounting, trade recording, snapshots). Add 12 unit tests + 4 integration tests (51 total portfolio tests pass). Fix cash_pct bug in models.py, update docs for psycopg2 + pooler pattern. Co-Authored-By: Claude Opus 4.6 --- docs/agent/CURRENT_STATE.md | 15 +- docs/portfolio/00_overview.md | 34 +- docs/portfolio/01_phase1_plan.md | 19 +- tests/portfolio/conftest.py | 20 +- tests/portfolio/test_repository.py | 320 +++++++++++--- tradingagents/portfolio/config.py | 107 +++-- tradingagents/portfolio/models.py | 2 +- tradingagents/portfolio/repository.py | 330 +++++++-------- tradingagents/portfolio/supabase_client.py | 467 +++++++++++++-------- 9 files changed, 816 insertions(+), 498 deletions(-) diff --git a/docs/agent/CURRENT_STATE.md b/docs/agent/CURRENT_STATE.md index b1155699..93555e26 100644 --- a/docs/agent/CURRENT_STATE.md +++ b/docs/agent/CURRENT_STATE.md @@ -1,20 +1,21 @@ # Current Milestone -Daily digest consolidation and Google NotebookLM sync shipped (PR open: `feat/daily-digest-notebooklm`). All analyses now append to a single `daily_digest.md` per day and auto-upload to NotebookLM via `nlm` CLI. Next: PR review and merge. +Portfolio Manager Phase 1 (data foundation) complete and merged. All 4 Supabase tables live, 51 tests passing (including integration tests against live DB). # Recent Progress +- **PR #32 merged**: Portfolio Manager data foundation — models, SQL schema, module scaffolding + - `tradingagents/portfolio/` — full module: models, config, exceptions, supabase_client (psycopg2), report_store, repository + - `migrations/001_initial_schema.sql` — 4 tables (portfolios, holdings, trades, snapshots) with constraints, indexes, triggers + - `tests/portfolio/` — 51 tests: 20 model, 15 report_store, 12 repository unit, 4 integration + - Uses `psycopg2` direct PostgreSQL via Supabase pooler (`aws-1-eu-west-1.pooler.supabase.com:6543`) + - Business logic: avg cost basis, cash accounting, trade recording, snapshots - **PR #22 merged**: Unified report paths, structured observability logging, memory system update - **feat/daily-digest-notebooklm** (shipped): Daily digest consolidation + NotebookLM source sync - - `tradingagents/daily_digest.py` — `append_to_digest()` appends timestamped entries to `reports/daily/{date}/daily_digest.md` - - `tradingagents/notebook_sync.py` — `sync_to_notebooklm()` deletes existing "Daily Trading Digest" source then uploads new content via `nlm source add --text --wait`. - - `tradingagents/report_paths.py` — added `get_digest_path(date)` - - `cli/main.py` — `analyze` and `scan` commands both call digest + sync after each run - - `.env.example` — fixed consistency, removed duplicates, aligned with `NOTEBOOKLM_ID` -- **Verification**: 220+ offline tests passing + 5 new unit tests for `notebook_sync.py` + live integration test passed. # In Progress +- Portfolio Manager Phase 2: Holding Reviewer Agent (next) - Refinement of macro scan synthesis prompts (ongoing) # Active Blockers diff --git a/docs/portfolio/00_overview.md b/docs/portfolio/00_overview.md index 741a2605..17753d2c 100644 --- a/docs/portfolio/00_overview.md +++ b/docs/portfolio/00_overview.md @@ -72,28 +72,26 @@ investment portfolio end-to-end. It performs the following actions in sequence: The `report_path` column in the `portfolios` table points to the daily portfolio subdirectory on disk: `reports/daily/{date}/portfolio/`. -### Data Access Layer: raw `supabase-py` (no ORM) +### Data Access Layer: raw `psycopg2` (no ORM) -The Python code talks to Supabase through the raw `supabase-py` client — **no -ORM** (Prisma, SQLAlchemy, etc.) is used. +The Python code talks to Supabase PostgreSQL directly via `psycopg2` using the +**pooler connection string** (`SUPABASE_CONNECTION_STRING`). No ORM (Prisma, +SQLAlchemy) and no `supabase-py` REST client is used. -**Why not Prisma?** -- `prisma-client-py` requires a Node.js runtime for code generation — an - extra non-Python dependency in a Python-only project. -- Prisma's `prisma migrate` conflicts with Supabase's own SQL migration tooling - (we use `.sql` files in `tradingagents/portfolio/migrations/`). -- 4 tables with straightforward CRUD don't benefit from a code-generated ORM. +**Why `psycopg2` over `supabase-py`?** +- Direct SQL gives full control — transactions, upserts, `RETURNING *`, CTEs. +- No dependency on Supabase's PostgREST schema cache or API key types. +- `psycopg2-binary` is a single pip install with zero non-Python dependencies. +- 4 tables with straightforward CRUD don't benefit from an ORM or REST wrapper. -**Why not SQLAlchemy?** -- Supabase is accessed via PostgREST (HTTP API), not a direct TCP database - connection. SQLAlchemy is designed for direct connections and would bypass - Supabase's Row Level Security. -- Extra dependency overhead for a non-DB-heavy feature. +**Connection:** +- Uses `SUPABASE_CONNECTION_STRING` env var (pooler URI format). +- Passwords with special characters are auto-URL-encoded by `SupabaseClient._fix_dsn()`. +- Typical pooler URI: `postgresql://postgres.:@aws-1-.pooler.supabase.com:6543/postgres` -**`supabase-py` is sufficient because:** -- Its builder-pattern API (`client.table("holdings").select("*").eq(...)`) - covers all needed queries cleanly. -- Our own dataclasses handle type-safety via `to_dict()` / `from_dict()`. +**Why not Prisma / SQLAlchemy?** +- Prisma requires Node.js runtime — extra non-Python dependency. +- SQLAlchemy adds dependency overhead for 4 simple tables. - Plain SQL migration files are readable, versionable, and Supabase-native. > Full rationale: `docs/agent/decisions/012-portfolio-no-orm.md` diff --git a/docs/portfolio/01_phase1_plan.md b/docs/portfolio/01_phase1_plan.md index dd14c2ee..71e82629 100644 --- a/docs/portfolio/01_phase1_plan.md +++ b/docs/portfolio/01_phase1_plan.md @@ -89,14 +89,13 @@ All are optional and default to the values shown: | Env Var | Default | Description | |---------|---------|-------------| -| `SUPABASE_URL` | `""` | Supabase project URL | -| `SUPABASE_KEY` | `""` | Supabase anon/service key | -| `PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports | -| `PM_MAX_POSITIONS` | `15` | Max number of open positions | -| `PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight | -| `PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight | -| `PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve | -| `PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) | +| `SUPABASE_CONNECTION_STRING` | `""` | PostgreSQL pooler connection URI | +| `TRADINGAGENTS_PORTFOLIO_DATA_DIR` | `"reports"` | Root dir for filesystem reports | +| `TRADINGAGENTS_PM_MAX_POSITIONS` | `15` | Max number of open positions | +| `TRADINGAGENTS_PM_MAX_POSITION_PCT` | `0.15` | Max single-position weight | +| `TRADINGAGENTS_PM_MAX_SECTOR_PCT` | `0.35` | Max sector weight | +| `TRADINGAGENTS_PM_MIN_CASH_PCT` | `0.05` | Minimum cash reserve | +| `TRADINGAGENTS_PM_DEFAULT_BUDGET` | `100000.0` | Default starting cash (USD) | ### Acceptance Criteria @@ -172,7 +171,7 @@ Methods that query a single row raise `PortfolioNotFoundError` when no row is fo - Singleton — only one Supabase connection per process - All public methods fully type-annotated -- Supabase integration tests auto-skip when `SUPABASE_URL` is unset +- Supabase integration tests auto-skip when `SUPABASE_CONNECTION_STRING` is unset --- @@ -328,7 +327,7 @@ require PG functions — deferred to Phase 3+). ### Coverage Target 90 %+ for `models.py` and `report_store.py`. -Integration tests (`test_repository.py`) auto-skip when Supabase is unavailable. +Integration tests (`test_repository.py`) auto-skip when `SUPABASE_CONNECTION_STRING` is unset. --- diff --git a/tests/portfolio/conftest.py b/tests/portfolio/conftest.py index 073d318d..9a7d61f8 100644 --- a/tests/portfolio/conftest.py +++ b/tests/portfolio/conftest.py @@ -2,16 +2,16 @@ Fixtures provided: -- ``tmp_reports`` — temporary directory used as ReportStore base_dir -- ``sample_portfolio`` — a Portfolio instance for testing (not persisted) -- ``sample_holding`` — a Holding instance for testing (not persisted) -- ``sample_trade`` — a Trade instance for testing (not persisted) -- ``sample_snapshot`` — a PortfolioSnapshot instance for testing -- ``report_store`` — a ReportStore instance backed by tmp_reports -- ``mock_supabase_client`` — MagicMock of SupabaseClient for unit tests +- ``tmp_reports`` -- temporary directory used as ReportStore base_dir +- ``sample_portfolio`` -- a Portfolio instance for testing (not persisted) +- ``sample_holding`` -- a Holding instance for testing (not persisted) +- ``sample_trade`` -- a Trade instance for testing (not persisted) +- ``sample_snapshot`` -- a PortfolioSnapshot instance for testing +- ``report_store`` -- a ReportStore instance backed by tmp_reports +- ``mock_supabase_client`` -- MagicMock of SupabaseClient for unit tests Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when -``SUPABASE_URL`` is not set in the environment. +``SUPABASE_CONNECTION_STRING`` is not set in the environment. """ from __future__ import annotations @@ -36,8 +36,8 @@ from tradingagents.portfolio.supabase_client import SupabaseClient # --------------------------------------------------------------------------- requires_supabase = pytest.mark.skipif( - not os.getenv("SUPABASE_URL"), - reason="SUPABASE_URL not set — skipping Supabase integration tests", + not os.getenv("SUPABASE_CONNECTION_STRING"), + reason="SUPABASE_CONNECTION_STRING not set -- skipping integration tests", ) diff --git a/tests/portfolio/test_repository.py b/tests/portfolio/test_repository.py index 3658ab26..958d93b9 100644 --- a/tests/portfolio/test_repository.py +++ b/tests/portfolio/test_repository.py @@ -1,29 +1,59 @@ """Tests for tradingagents/portfolio/repository.py. -Tests the PortfolioRepository façade — business logic for holdings management, -cash accounting, avg-cost-basis updates, and snapshot creation. - -Supabase integration tests are automatically skipped when ``SUPABASE_URL`` is -not set in the environment (use the ``requires_supabase`` fixture marker). - Unit tests use ``mock_supabase_client`` to avoid DB access. +Integration tests auto-skip when ``SUPABASE_CONNECTION_STRING`` is unset. Run (unit tests only):: pytest tests/portfolio/test_repository.py -v -k "not integration" - -Run (with Supabase):: - - SUPABASE_URL=... SUPABASE_KEY=... pytest tests/portfolio/test_repository.py -v """ from __future__ import annotations +from unittest.mock import MagicMock, call + import pytest +from tradingagents.portfolio.exceptions import ( + HoldingNotFoundError, + InsufficientCashError, + InsufficientSharesError, +) +from tradingagents.portfolio.models import Holding, Portfolio, Trade +from tradingagents.portfolio.repository import PortfolioRepository from tests.portfolio.conftest import requires_supabase +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_repo(mock_client, report_store): + """Build a PortfolioRepository with mock client and real report store.""" + return PortfolioRepository( + client=mock_client, + store=report_store, + config={"data_dir": "reports", "max_positions": 15, + "max_position_pct": 0.15, "max_sector_pct": 0.35, + "min_cash_pct": 0.05, "default_budget": 100_000.0, + "supabase_connection_string": ""}, + ) + + +def _mock_portfolio(pid="pid-1", cash=10_000.0): + return Portfolio( + portfolio_id=pid, name="Test", cash=cash, + initial_cash=100_000.0, currency="USD", + ) + + +def _mock_holding(pid="pid-1", ticker="AAPL", shares=50.0, avg_cost=190.0): + return Holding( + holding_id="hid-1", portfolio_id=pid, ticker=ticker, + shares=shares, avg_cost=avg_cost, + ) + + # --------------------------------------------------------------------------- # add_holding — new position # --------------------------------------------------------------------------- @@ -31,12 +61,19 @@ from tests.portfolio.conftest import requires_supabase def test_add_holding_new_position(mock_supabase_client, report_store): """add_holding() on a ticker not yet held must create a new Holding.""" - # TODO: implement - # repo = PortfolioRepository(client=mock_supabase_client, store=report_store) - # mock portfolio with enough cash - # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) - # assert mock_supabase_client.upsert_holding.called - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + holding = repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + assert holding.ticker == "AAPL" + assert holding.shares == 10 + assert holding.avg_cost == 200.0 + assert mock_supabase_client.upsert_holding.called # --------------------------------------------------------------------------- @@ -45,16 +82,21 @@ def test_add_holding_new_position(mock_supabase_client, report_store): def test_add_holding_updates_avg_cost(mock_supabase_client, report_store): - """add_holding() on an existing position must update avg_cost correctly. + """add_holding() on existing position must update avg_cost correctly.""" + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + existing = _mock_holding(shares=50.0, avg_cost=190.0) + mock_supabase_client.get_holding.return_value = existing + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t - Formula: new_avg_cost = (old_shares * old_avg_cost + new_shares * price) - / (old_shares + new_shares) - """ - # TODO: implement - # existing holding: 50 shares @ 190.0 - # buy 25 more @ 200.0 - # expected avg_cost = (50*190 + 25*200) / 75 = 193.33... - raise NotImplementedError + repo = _make_repo(mock_supabase_client, report_store) + holding = repo.add_holding("pid-1", "AAPL", shares=25, price=200.0) + + # expected: (50*190 + 25*200) / 75 = 193.333... + expected_avg = (50 * 190.0 + 25 * 200.0) / 75 + assert holding.shares == 75 + assert holding.avg_cost == pytest.approx(expected_avg) # --------------------------------------------------------------------------- @@ -64,11 +106,11 @@ def test_add_holding_updates_avg_cost(mock_supabase_client, report_store): def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store): """add_holding() must raise InsufficientCashError when cash < shares * price.""" - # TODO: implement - # portfolio with cash=500.0, try to buy 10 shares @ 200.0 (cost=2000) - # with pytest.raises(InsufficientCashError): - # repo.add_holding(portfolio_id, "AAPL", shares=10, price=200.0) - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=500.0) + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(InsufficientCashError): + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) # --------------------------------------------------------------------------- @@ -78,11 +120,17 @@ def test_add_holding_raises_insufficient_cash(mock_supabase_client, report_store def test_remove_holding_full_position(mock_supabase_client, report_store): """remove_holding() selling all shares must delete the holding row.""" - # TODO: implement - # holding: 50 shares - # sell 50 shares → holding deleted, cash credited - # assert mock_supabase_client.delete_holding.called - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.delete_holding.return_value = None + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + result = repo.remove_holding("pid-1", "AAPL", shares=50.0, price=200.0) + + assert result is None + assert mock_supabase_client.delete_holding.called # --------------------------------------------------------------------------- @@ -92,10 +140,17 @@ def test_remove_holding_full_position(mock_supabase_client, report_store): def test_remove_holding_partial_position(mock_supabase_client, report_store): """remove_holding() selling a subset must reduce shares, not delete.""" - # TODO: implement - # holding: 50 shares - # sell 20 → holding.shares == 30, avg_cost unchanged - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + result = repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + assert result is not None + assert result.shares == 30.0 # --------------------------------------------------------------------------- @@ -105,16 +160,20 @@ def test_remove_holding_partial_position(mock_supabase_client, report_store): def test_remove_holding_raises_insufficient_shares(mock_supabase_client, report_store): """remove_holding() must raise InsufficientSharesError when shares > held.""" - # TODO: implement - # holding: 10 shares - # try sell 20 → InsufficientSharesError - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=10.0) + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(InsufficientSharesError): + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report_store): """remove_holding() must raise HoldingNotFoundError for unknown tickers.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_holding.return_value = None + + repo = _make_repo(mock_supabase_client, report_store) + with pytest.raises(HoldingNotFoundError): + repo.remove_holding("pid-1", "ZZZZ", shares=10.0, price=100.0) # --------------------------------------------------------------------------- @@ -124,15 +183,35 @@ def test_remove_holding_raises_when_ticker_not_held(mock_supabase_client, report def test_add_holding_deducts_cash(mock_supabase_client, report_store): """add_holding() must reduce portfolio.cash by shares * price.""" - # TODO: implement - # portfolio.cash = 10_000, buy 10 @ 200 → cash should be 8_000 - raise NotImplementedError + portfolio = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + # Check the portfolio passed to update_portfolio had cash deducted + updated = mock_supabase_client.update_portfolio.call_args[0][0] + assert updated.cash == pytest.approx(8_000.0) def test_remove_holding_credits_cash(mock_supabase_client, report_store): """remove_holding() must increase portfolio.cash by shares * price.""" - # TODO: implement - raise NotImplementedError + portfolio = _mock_portfolio(cash=5_000.0) + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + updated = mock_supabase_client.update_portfolio.call_args[0][0] + assert updated.cash == pytest.approx(9_000.0) # --------------------------------------------------------------------------- @@ -142,14 +221,37 @@ def test_remove_holding_credits_cash(mock_supabase_client, report_store): def test_add_holding_records_buy_trade(mock_supabase_client, report_store): """add_holding() must call client.record_trade() with action='BUY'.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=10_000.0) + mock_supabase_client.get_holding.return_value = None + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.add_holding("pid-1", "AAPL", shares=10, price=200.0) + + trade = mock_supabase_client.record_trade.call_args[0][0] + assert trade.action == "BUY" + assert trade.ticker == "AAPL" + assert trade.shares == 10 + assert trade.total_value == pytest.approx(2_000.0) def test_remove_holding_records_sell_trade(mock_supabase_client, report_store): """remove_holding() must call client.record_trade() with action='SELL'.""" - # TODO: implement - raise NotImplementedError + mock_supabase_client.get_holding.return_value = _mock_holding(shares=50.0) + mock_supabase_client.get_portfolio.return_value = _mock_portfolio(cash=5_000.0) + mock_supabase_client.upsert_holding.side_effect = lambda h: h + mock_supabase_client.update_portfolio.side_effect = lambda p: p + mock_supabase_client.record_trade.side_effect = lambda t: t + + repo = _make_repo(mock_supabase_client, report_store) + repo.remove_holding("pid-1", "AAPL", shares=20.0, price=200.0) + + trade = mock_supabase_client.record_trade.call_args[0][0] + assert trade.action == "SELL" + assert trade.ticker == "AAPL" + assert trade.shares == 20.0 # --------------------------------------------------------------------------- @@ -159,40 +261,120 @@ def test_remove_holding_records_sell_trade(mock_supabase_client, report_store): def test_take_snapshot(mock_supabase_client, report_store): """take_snapshot() must enrich holdings and persist a PortfolioSnapshot.""" - # TODO: implement - # assert mock_supabase_client.save_snapshot.called - # snapshot.total_value == cash + equity - raise NotImplementedError + portfolio = _mock_portfolio(cash=5_000.0) + holding = _mock_holding(shares=50.0, avg_cost=190.0) + mock_supabase_client.get_portfolio.return_value = portfolio + mock_supabase_client.list_holdings.return_value = [holding] + mock_supabase_client.save_snapshot.side_effect = lambda s: s + + repo = _make_repo(mock_supabase_client, report_store) + snapshot = repo.take_snapshot("pid-1", prices={"AAPL": 200.0}) + + assert mock_supabase_client.save_snapshot.called + assert snapshot.cash == 5_000.0 + assert snapshot.num_positions == 1 + assert snapshot.total_value == pytest.approx(5_000.0 + 50.0 * 200.0) # --------------------------------------------------------------------------- -# Supabase integration tests (auto-skip without SUPABASE_URL) +# Supabase integration tests (auto-skip without SUPABASE_CONNECTION_STRING) # --------------------------------------------------------------------------- @requires_supabase def test_integration_create_and_get_portfolio(): """Integration: create a portfolio, retrieve it, verify fields match.""" - # TODO: implement - raise NotImplementedError + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Integration Test", initial_cash=50_000.0) + try: + fetched = repo.get_portfolio(portfolio.portfolio_id) + assert fetched.name == "Integration Test" + assert fetched.cash == pytest.approx(50_000.0) + assert fetched.initial_cash == pytest.approx(50_000.0) + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_add_and_remove_holding(): - """Integration: add holding, verify DB row; remove, verify deletion.""" - # TODO: implement - raise NotImplementedError + """Integration: add holding, verify; remove, verify deletion.""" + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Hold Test", initial_cash=50_000.0) + try: + holding = repo.add_holding( + portfolio.portfolio_id, "AAPL", shares=10, price=200.0, + sector="Technology", + ) + assert holding.ticker == "AAPL" + assert holding.shares == 10 + + # Verify cash deducted + p = repo.get_portfolio(portfolio.portfolio_id) + assert p.cash == pytest.approx(48_000.0) + + # Sell all + result = repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=10, price=210.0) + assert result is None + + # Verify cash credited + p = repo.get_portfolio(portfolio.portfolio_id) + assert p.cash == pytest.approx(50_100.0) + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_record_and_list_trades(): - """Integration: record BUY + SELL trades, list them, verify order.""" - # TODO: implement - raise NotImplementedError + """Integration: trades are recorded automatically via add/remove holding.""" + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Trade Test", initial_cash=50_000.0) + try: + repo.add_holding(portfolio.portfolio_id, "MSFT", shares=5, price=400.0) + repo.remove_holding(portfolio.portfolio_id, "MSFT", shares=5, price=410.0) + + trades = client.list_trades(portfolio.portfolio_id) + assert len(trades) == 2 + assert trades[0].action == "SELL" # newest first + assert trades[1].action == "BUY" + finally: + client.delete_portfolio(portfolio.portfolio_id) @requires_supabase def test_integration_save_and_load_snapshot(): """Integration: take snapshot, retrieve latest, verify total_value.""" - # TODO: implement - raise NotImplementedError + from tradingagents.portfolio.supabase_client import SupabaseClient + client = SupabaseClient.get_instance() + from tradingagents.portfolio.report_store import ReportStore + store = ReportStore() + + repo = PortfolioRepository(client=client, store=store) + portfolio = repo.create_portfolio("Snap Test", initial_cash=50_000.0) + try: + repo.add_holding(portfolio.portfolio_id, "AAPL", shares=10, price=200.0) + snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 210.0}) + + assert snapshot.num_positions == 1 + assert snapshot.cash == pytest.approx(48_000.0) + assert snapshot.total_value == pytest.approx(48_000.0 + 10 * 210.0) + + latest = client.get_latest_snapshot(portfolio.portfolio_id) + assert latest is not None + assert latest.snapshot_id == snapshot.snapshot_id + finally: + client.delete_portfolio(portfolio.portfolio_id) diff --git a/tradingagents/portfolio/config.py b/tradingagents/portfolio/config.py index 4ffbc403..d8293e49 100644 --- a/tradingagents/portfolio/config.py +++ b/tradingagents/portfolio/config.py @@ -1,8 +1,7 @@ """Portfolio Manager configuration. -Reads all portfolio-related settings from environment variables with sensible -defaults. Integrates with the existing ``tradingagents/default_config.py`` -pattern. +Integrates with the existing ``tradingagents/default_config.py`` pattern, +reading all portfolio settings from ``TRADINGAGENTS_`` env vars. Usage:: @@ -11,34 +10,51 @@ Usage:: cfg = get_portfolio_config() validate_config(cfg) print(cfg["max_positions"]) # 15 - -Environment variables (all optional): - - SUPABASE_URL Supabase project URL (default: "") - SUPABASE_KEY Supabase anon/service role key (default: "") - PORTFOLIO_DATA_DIR Root dir for filesystem reports (default: "reports") - PM_MAX_POSITIONS Max open positions (default: 15) - PM_MAX_POSITION_PCT Max single-position weight 0–1 (default: 0.15) - PM_MAX_SECTOR_PCT Max sector weight 0–1 (default: 0.35) - PM_MIN_CASH_PCT Minimum cash reserve 0–1 (default: 0.05) - PM_DEFAULT_BUDGET Default starting cash in USD (default: 100000.0) """ from __future__ import annotations import os -# --------------------------------------------------------------------------- -# Defaults -# --------------------------------------------------------------------------- +from dotenv import load_dotenv + +load_dotenv() + + +def _env(key: str, default=None): + """Read ``TRADINGAGENTS_`` from the environment. + + Matches the convention in ``tradingagents/default_config.py``. + """ + val = os.getenv(f"TRADINGAGENTS_{key.upper()}") + if not val: + return default + return val + + +def _env_float(key: str, default=None): + val = _env(key) + if val is None: + return default + try: + return float(val) + except (ValueError, TypeError): + return default + + +def _env_int(key: str, default=None): + val = _env(key) + if val is None: + return default + try: + return int(val) + except (ValueError, TypeError): + return default + PORTFOLIO_CONFIG: dict = { - # Supabase connection - "supabase_url": "", - "supabase_key": "", - # Filesystem report root (matches report_paths.py REPORTS_ROOT) - "data_dir": "reports", - # PM constraint defaults + "supabase_connection_string": os.getenv("SUPABASE_CONNECTION_STRING", ""), + "data_dir": _env("PORTFOLIO_DATA_DIR", "reports"), "max_positions": 15, "max_position_pct": 0.15, "max_sector_pct": 0.35, @@ -47,41 +63,44 @@ PORTFOLIO_CONFIG: dict = { } -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - def get_portfolio_config() -> dict: """Return the merged portfolio config (defaults overridden by env vars). - Reads ``SUPABASE_URL``, ``SUPABASE_KEY``, ``PORTFOLIO_DATA_DIR``, - ``PM_MAX_POSITIONS``, ``PM_MAX_POSITION_PCT``, ``PM_MAX_SECTOR_PCT``, - ``PM_MIN_CASH_PCT``, and ``PM_DEFAULT_BUDGET`` from the environment. - Returns: A dict with all portfolio configuration keys. """ - # TODO: implement — merge PORTFOLIO_CONFIG with env var overrides - raise NotImplementedError + cfg = dict(PORTFOLIO_CONFIG) + cfg["supabase_connection_string"] = os.getenv("SUPABASE_CONNECTION_STRING", cfg["supabase_connection_string"]) + cfg["data_dir"] = _env("PORTFOLIO_DATA_DIR", cfg["data_dir"]) + cfg["max_positions"] = _env_int("PM_MAX_POSITIONS", cfg["max_positions"]) + cfg["max_position_pct"] = _env_float("PM_MAX_POSITION_PCT", cfg["max_position_pct"]) + cfg["max_sector_pct"] = _env_float("PM_MAX_SECTOR_PCT", cfg["max_sector_pct"]) + cfg["min_cash_pct"] = _env_float("PM_MIN_CASH_PCT", cfg["min_cash_pct"]) + cfg["default_budget"] = _env_float("PM_DEFAULT_BUDGET", cfg["default_budget"]) + return cfg def validate_config(cfg: dict) -> None: """Validate a portfolio config dict, raising ValueError on invalid values. - Checks: - - ``max_positions >= 1`` - - ``0 < max_position_pct <= 1`` - - ``0 < max_sector_pct <= 1`` - - ``0 <= min_cash_pct < 1`` - - ``default_budget > 0`` - - ``min_cash_pct + max_position_pct <= 1`` (can always meet both constraints) - Args: cfg: Config dict as returned by ``get_portfolio_config()``. Raises: ValueError: With a descriptive message on the first failed check. """ - # TODO: implement - raise NotImplementedError + if cfg["max_positions"] < 1: + raise ValueError(f"max_positions must be >= 1, got {cfg['max_positions']}") + if not (0 < cfg["max_position_pct"] <= 1.0): + raise ValueError(f"max_position_pct must be in (0, 1], got {cfg['max_position_pct']}") + if not (0 < cfg["max_sector_pct"] <= 1.0): + raise ValueError(f"max_sector_pct must be in (0, 1], got {cfg['max_sector_pct']}") + if not (0 <= cfg["min_cash_pct"] < 1.0): + raise ValueError(f"min_cash_pct must be in [0, 1), got {cfg['min_cash_pct']}") + if cfg["default_budget"] <= 0: + raise ValueError(f"default_budget must be > 0, got {cfg['default_budget']}") + if cfg["min_cash_pct"] + cfg["max_position_pct"] > 1.0: + raise ValueError( + f"min_cash_pct ({cfg['min_cash_pct']}) + max_position_pct ({cfg['max_position_pct']}) " + f"must be <= 1.0" + ) diff --git a/tradingagents/portfolio/models.py b/tradingagents/portfolio/models.py index 0cbfd442..f9989af4 100644 --- a/tradingagents/portfolio/models.py +++ b/tradingagents/portfolio/models.py @@ -106,7 +106,7 @@ class Portfolio: h.current_value for h in holdings if h.current_value is not None ) self.total_value = self.cash + self.equity_value - self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0 + self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 1.0 return self diff --git a/tradingagents/portfolio/repository.py b/tradingagents/portfolio/repository.py index 123803fd..0a556ef5 100644 --- a/tradingagents/portfolio/repository.py +++ b/tradingagents/portfolio/repository.py @@ -1,38 +1,28 @@ -"""Unified data-access façade for the Portfolio Manager. +"""Unified data-access facade for the Portfolio Manager. ``PortfolioRepository`` combines ``SupabaseClient`` (transactional data) and ``ReportStore`` (filesystem documents) into a single, business-logic-aware interface. -Callers should **only** interact with ``PortfolioRepository`` — do not use -``SupabaseClient`` or ``ReportStore`` directly from outside this package. - Usage:: from tradingagents.portfolio import PortfolioRepository repo = PortfolioRepository() - - # Create a portfolio portfolio = repo.create_portfolio("Main Portfolio", initial_cash=100_000.0) - - # Buy shares holding = repo.add_holding(portfolio.portfolio_id, "AAPL", shares=50, price=195.50) - # Sell shares - repo.remove_holding(portfolio.portfolio_id, "AAPL", shares=25, price=200.00) - - # Snapshot - snapshot = repo.take_snapshot(portfolio.portfolio_id, prices={"AAPL": 200.00}) - See ``docs/portfolio/04_repository_api.md`` for full API documentation. """ from __future__ import annotations +import uuid +from datetime import datetime, timezone from pathlib import Path from typing import Any +from tradingagents.portfolio.config import get_portfolio_config from tradingagents.portfolio.exceptions import ( HoldingNotFoundError, InsufficientCashError, @@ -49,7 +39,7 @@ from tradingagents.portfolio.supabase_client import SupabaseClient class PortfolioRepository: - """Unified façade over SupabaseClient and ReportStore. + """Unified facade over SupabaseClient and ReportStore. Implements business logic for: - Average cost basis updates on repeated buys @@ -64,16 +54,9 @@ class PortfolioRepository: store: ReportStore | None = None, config: dict[str, Any] | None = None, ) -> None: - """Initialise the repository. - - Args: - client: SupabaseClient instance. Uses ``SupabaseClient.get_instance()`` - when None. - store: ReportStore instance. Creates a default instance when None. - config: Portfolio config dict. Uses ``get_portfolio_config()`` when None. - """ - # TODO: implement — resolve defaults, store as self._client, self._store, self._cfg - raise NotImplementedError + self._cfg = config or get_portfolio_config() + self._client = client or SupabaseClient.get_instance() + self._store = store or ReportStore(base_dir=self._cfg["data_dir"]) # ------------------------------------------------------------------ # Portfolio lifecycle @@ -85,54 +68,42 @@ class PortfolioRepository: initial_cash: float, currency: str = "USD", ) -> Portfolio: - """Create a new portfolio with the given starting capital. - - Generates a UUID for ``portfolio_id``. Sets ``cash = initial_cash``. - - Args: - name: Human-readable portfolio name. - initial_cash: Starting capital in USD (or configured currency). - currency: ISO 4217 currency code. - - Returns: - Persisted Portfolio instance. - - Raises: - DuplicatePortfolioError: If the name is already in use. - ValueError: If ``initial_cash <= 0``. - """ - # TODO: implement - raise NotImplementedError + """Create a new portfolio with the given starting capital.""" + if initial_cash <= 0: + raise ValueError(f"initial_cash must be > 0, got {initial_cash}") + portfolio = Portfolio( + portfolio_id=str(uuid.uuid4()), + name=name, + cash=initial_cash, + initial_cash=initial_cash, + currency=currency, + ) + return self._client.create_portfolio(portfolio) def get_portfolio(self, portfolio_id: str) -> Portfolio: - """Fetch a portfolio by ID. - - Args: - portfolio_id: UUID of the target portfolio. - - Raises: - PortfolioNotFoundError: If not found. - """ - # TODO: implement - raise NotImplementedError + """Fetch a portfolio by ID.""" + return self._client.get_portfolio(portfolio_id) def get_portfolio_with_holdings( self, portfolio_id: str, prices: dict[str, float] | None = None, ) -> tuple[Portfolio, list[Holding]]: - """Fetch portfolio + all holdings, optionally enriched with current prices. - - Args: - portfolio_id: UUID of the target portfolio. - prices: Optional ``{ticker: current_price}`` dict. When provided, - holdings are enriched and ``Portfolio.enrich()`` is called. - - Returns: - ``(Portfolio, list[Holding])`` — enriched when prices are supplied. - """ - # TODO: implement - raise NotImplementedError + """Fetch portfolio + all holdings, optionally enriched with current prices.""" + portfolio = self._client.get_portfolio(portfolio_id) + holdings = self._client.list_holdings(portfolio_id) + if prices: + # First pass: compute equity for total_value + equity = sum( + prices.get(h.ticker, 0.0) * h.shares for h in holdings + ) + total_value = portfolio.cash + equity + # Second pass: enrich each holding with weight + for h in holdings: + if h.ticker in prices: + h.enrich(prices[h.ticker], total_value) + portfolio.enrich(holdings) + return portfolio, holdings # ------------------------------------------------------------------ # Holdings management @@ -147,37 +118,65 @@ class PortfolioRepository: sector: str | None = None, industry: str | None = None, ) -> Holding: - """Buy shares and update portfolio cash and holdings. + """Buy shares and update portfolio cash and holdings.""" + if shares <= 0: + raise ValueError(f"shares must be > 0, got {shares}") + if price <= 0: + raise ValueError(f"price must be > 0, got {price}") - Business logic: - - Raises ``InsufficientCashError`` if ``portfolio.cash < shares * price`` - - If holding already exists: updates ``avg_cost`` using weighted average - - ``portfolio.cash -= shares * price`` - - Records a BUY trade automatically + cost = shares * price + portfolio = self._client.get_portfolio(portfolio_id) - Avg cost formula:: + if portfolio.cash < cost: + raise InsufficientCashError( + f"Need ${cost:.2f} but only ${portfolio.cash:.2f} available" + ) - new_avg_cost = (old_shares * old_avg_cost + new_shares * price) - / (old_shares + new_shares) + # Check for existing holding to update avg cost + existing = self._client.get_holding(portfolio_id, ticker) + if existing: + new_total_shares = existing.shares + shares + new_avg_cost = ( + (existing.shares * existing.avg_cost + shares * price) / new_total_shares + ) + existing.shares = new_total_shares + existing.avg_cost = new_avg_cost + if sector: + existing.sector = sector + if industry: + existing.industry = industry + holding = self._client.upsert_holding(existing) + else: + holding = Holding( + holding_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + shares=shares, + avg_cost=price, + sector=sector, + industry=industry, + ) + holding = self._client.upsert_holding(holding) - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - shares: Number of shares to buy (must be > 0). - price: Execution price per share. - sector: Optional GICS sector name. - industry: Optional GICS industry name. + # Deduct cash + portfolio.cash -= cost + self._client.update_portfolio(portfolio) - Returns: - Updated or created Holding. + # Record trade + trade = Trade( + trade_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + action="BUY", + shares=shares, + price=price, + total_value=cost, + trade_date=datetime.now(timezone.utc).isoformat(), + signal_source="pm_agent", + ) + self._client.record_trade(trade) - Raises: - InsufficientCashError: If cash would go negative. - PortfolioNotFoundError: If portfolio_id does not exist. - ValueError: If shares <= 0 or price <= 0. - """ - # TODO: implement - raise NotImplementedError + return holding def remove_holding( self, @@ -186,33 +185,53 @@ class PortfolioRepository: shares: float, price: float, ) -> Holding | None: - """Sell shares and update portfolio cash and holdings. + """Sell shares and update portfolio cash and holdings.""" + if shares <= 0: + raise ValueError(f"shares must be > 0, got {shares}") + if price <= 0: + raise ValueError(f"price must be > 0, got {price}") - Business logic: - - Raises ``HoldingNotFoundError`` if no holding exists for ticker - - Raises ``InsufficientSharesError`` if ``holding.shares < shares`` - - If ``shares == holding.shares``: deletes the holding row, returns None - - Otherwise: decrements ``holding.shares`` (avg_cost unchanged on sell) - - ``portfolio.cash += shares * price`` - - Records a SELL trade automatically + existing = self._client.get_holding(portfolio_id, ticker) + if not existing: + raise HoldingNotFoundError( + f"No holding for {ticker} in portfolio {portfolio_id}" + ) - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - shares: Number of shares to sell (must be > 0). - price: Execution price per share. + if existing.shares < shares: + raise InsufficientSharesError( + f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}" + ) - Returns: - Updated Holding, or None if the position was fully closed. + proceeds = shares * price + portfolio = self._client.get_portfolio(portfolio_id) - Raises: - HoldingNotFoundError: If no holding exists for this ticker. - InsufficientSharesError: If holding.shares < shares. - PortfolioNotFoundError: If portfolio_id does not exist. - ValueError: If shares <= 0 or price <= 0. - """ - # TODO: implement - raise NotImplementedError + if existing.shares == shares: + # Full sell — delete holding + self._client.delete_holding(portfolio_id, ticker) + result = None + else: + existing.shares -= shares + result = self._client.upsert_holding(existing) + + # Credit cash + portfolio.cash += proceeds + self._client.update_portfolio(portfolio) + + # Record trade + trade = Trade( + trade_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + action="SELL", + shares=shares, + price=price, + total_value=proceeds, + trade_date=datetime.now(timezone.utc).isoformat(), + signal_source="pm_agent", + ) + self._client.record_trade(trade) + + return result # ------------------------------------------------------------------ # Snapshots @@ -223,20 +242,19 @@ class PortfolioRepository: portfolio_id: str, prices: dict[str, float], ) -> PortfolioSnapshot: - """Take an immutable snapshot of the current portfolio state. - - Fetches all holdings, enriches them with ``prices``, computes - ``total_value``, then persists via ``SupabaseClient.save_snapshot()``. - - Args: - portfolio_id: UUID of the target portfolio. - prices: ``{ticker: current_price}`` for all held tickers. - - Returns: - Persisted PortfolioSnapshot. - """ - # TODO: implement - raise NotImplementedError + """Take an immutable snapshot of the current portfolio state.""" + portfolio, holdings = self.get_portfolio_with_holdings(portfolio_id, prices) + snapshot = PortfolioSnapshot( + snapshot_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + snapshot_date=datetime.now(timezone.utc).isoformat(), + total_value=portfolio.total_value or portfolio.cash, + cash=portfolio.cash, + equity_value=portfolio.equity_value or 0.0, + num_positions=len(holdings), + holdings_snapshot=[h.to_dict() for h in holdings], + ) + return self._client.save_snapshot(snapshot) # ------------------------------------------------------------------ # Report convenience methods @@ -249,37 +267,21 @@ class PortfolioRepository: decision: dict[str, Any], markdown: str | None = None, ) -> Path: - """Save a PM agent decision and update portfolio.report_path. - - Delegates to ``ReportStore.save_pm_decision()`` then updates the - ``portfolio.report_path`` column in Supabase to point to the daily - portfolio directory. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string, e.g. ``"2026-03-20"``. - decision: PM decision dict. - markdown: Optional human-readable markdown version. - - Returns: - Path of the written JSON file. - """ - # TODO: implement - raise NotImplementedError + """Save a PM agent decision and update portfolio.report_path.""" + path = self._store.save_pm_decision(date, portfolio_id, decision, markdown) + # Update portfolio report_path + portfolio = self._client.get_portfolio(portfolio_id) + portfolio.report_path = str(self._store._portfolio_dir(date)) + self._client.update_portfolio(portfolio) + return path def load_pm_decision( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: - """Load a PM decision JSON. Returns None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - """ - # TODO: implement - raise NotImplementedError + """Load a PM decision JSON. Returns None if not found.""" + return self._store.load_pm_decision(date, portfolio_id) def save_risk_metrics( self, @@ -287,29 +289,13 @@ class PortfolioRepository: date: str, metrics: dict[str, Any], ) -> Path: - """Save risk computation results. Delegates to ReportStore. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - metrics: Risk metrics dict (Sharpe, Sortino, VaR, beta, etc.). - - Returns: - Path of the written file. - """ - # TODO: implement - raise NotImplementedError + """Save risk computation results.""" + return self._store.save_risk_metrics(date, portfolio_id, metrics) def load_risk_metrics( self, portfolio_id: str, date: str, ) -> dict[str, Any] | None: - """Load risk metrics. Returns None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - date: ISO date string. - """ - # TODO: implement - raise NotImplementedError + """Load risk metrics. Returns None if not found.""" + return self._store.load_risk_metrics(date, portfolio_id) diff --git a/tradingagents/portfolio/supabase_client.py b/tradingagents/portfolio/supabase_client.py index efd87b34..c0107544 100644 --- a/tradingagents/portfolio/supabase_client.py +++ b/tradingagents/portfolio/supabase_client.py @@ -1,34 +1,30 @@ -"""Supabase database client for the Portfolio Manager. +"""PostgreSQL database client for the Portfolio Manager. -Thin wrapper around ``supabase-py`` (no ORM) that: -- Provides a singleton connection (one client per process) -- Translates Supabase / HTTP errors into domain exceptions -- Converts raw DB rows into typed model instances via ``Model.from_dict()`` - -**No ORM is used here by design** — see -``docs/agent/decisions/012-portfolio-no-orm.md`` for the full rationale. -In short: ``supabase-py``'s builder-pattern API is sufficient for 4 tables; -Prisma and SQLAlchemy add build-step / runtime complexity that isn't warranted -for this non-DB-heavy feature. +Uses ``psycopg2`` with the ``SUPABASE_CONNECTION_STRING`` env var to talk +directly to the Supabase-hosted PostgreSQL database. No ORM — see +``docs/agent/decisions/012-portfolio-no-orm.md`` for rationale. Usage:: from tradingagents.portfolio.supabase_client import SupabaseClient - from tradingagents.portfolio.models import Portfolio client = SupabaseClient.get_instance() portfolio = client.get_portfolio("some-uuid") - -Configuration (read from environment via ``get_portfolio_config()``): - SUPABASE_URL — Supabase project URL - SUPABASE_KEY — Supabase anon or service-role key """ from __future__ import annotations +import json +import uuid + +import psycopg2 +import psycopg2.extras + +from tradingagents.portfolio.config import get_portfolio_config from tradingagents.portfolio.exceptions import ( DuplicatePortfolioError, HoldingNotFoundError, + PortfolioError, PortfolioNotFoundError, ) from tradingagents.portfolio.models import ( @@ -40,167 +36,213 @@ from tradingagents.portfolio.models import ( class SupabaseClient: - """Singleton Supabase CRUD client for portfolio data. + """Singleton PostgreSQL CRUD client for portfolio data. - All public methods translate Supabase / HTTP errors into domain exceptions + All public methods translate database errors into domain exceptions and return typed model instances. - - Do not instantiate directly — use ``SupabaseClient.get_instance()``. """ - _instance: "SupabaseClient | None" = None + _instance: SupabaseClient | None = None - def __init__(self, url: str, key: str) -> None: - """Initialise the Supabase client. + def __init__(self, connection_string: str) -> None: + self._dsn = self._fix_dsn(connection_string) + self._conn = psycopg2.connect(self._dsn) + self._conn.autocommit = True - Args: - url: Supabase project URL. - key: Supabase anon or service-role key. - """ - # TODO: implement — create supabase.create_client(url, key) - raise NotImplementedError + @staticmethod + def _fix_dsn(dsn: str) -> str: + """URL-encode the password if it contains special characters.""" + from urllib.parse import quote + if "://" not in dsn: + return dsn # already key=value format + scheme, rest = dsn.split("://", 1) + at_idx = rest.rfind("@") + if at_idx == -1: + return dsn + userinfo = rest[:at_idx] + hostinfo = rest[at_idx + 1:] + colon_idx = userinfo.find(":") + if colon_idx == -1: + return dsn + user = userinfo[:colon_idx] + password = userinfo[colon_idx + 1:] + encoded = quote(password, safe="") + return f"{scheme}://{user}:{encoded}@{hostinfo}" @classmethod - def get_instance(cls) -> "SupabaseClient": - """Return the singleton instance, creating it if necessary. + def get_instance(cls) -> SupabaseClient: + """Return the singleton instance, creating it if necessary.""" + if cls._instance is None: + cfg = get_portfolio_config() + dsn = cfg["supabase_connection_string"] + if not dsn: + raise PortfolioError( + "SUPABASE_CONNECTION_STRING not configured. " + "Set it in .env or as an environment variable." + ) + cls._instance = cls(dsn) + return cls._instance - Reads SUPABASE_URL and SUPABASE_KEY from ``get_portfolio_config()``. + @classmethod + def reset_instance(cls) -> None: + """Close and reset the singleton (for testing).""" + if cls._instance is not None: + try: + cls._instance._conn.close() + except Exception: + pass + cls._instance = None - Raises: - PortfolioError: If SUPABASE_URL or SUPABASE_KEY are not configured. - """ - # TODO: implement - raise NotImplementedError + def _cursor(self): + """Return a RealDictCursor.""" + return self._conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) # ------------------------------------------------------------------ # Portfolio CRUD # ------------------------------------------------------------------ def create_portfolio(self, portfolio: Portfolio) -> Portfolio: - """Insert a new portfolio row. - - Args: - portfolio: Portfolio instance with all required fields set. - - Returns: - Portfolio with DB-assigned timestamps. - - Raises: - DuplicatePortfolioError: If portfolio_id already exists. - """ - # TODO: implement - raise NotImplementedError + """Insert a new portfolio row.""" + pid = portfolio.portfolio_id or str(uuid.uuid4()) + try: + with self._cursor() as cur: + cur.execute( + """INSERT INTO portfolios + (portfolio_id, name, cash, initial_cash, currency, report_path, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (pid, portfolio.name, portfolio.cash, portfolio.initial_cash, + portfolio.currency, portfolio.report_path, + json.dumps(portfolio.metadata)), + ) + row = cur.fetchone() + except psycopg2.errors.UniqueViolation as exc: + raise DuplicatePortfolioError(f"Portfolio already exists: {pid}") from exc + return self._row_to_portfolio(row) def get_portfolio(self, portfolio_id: str) -> Portfolio: - """Fetch a portfolio by ID. - - Args: - portfolio_id: UUID of the target portfolio. - - Returns: - Portfolio instance. - - Raises: - PortfolioNotFoundError: If no portfolio has that ID. - """ - # TODO: implement - raise NotImplementedError + """Fetch a portfolio by ID.""" + with self._cursor() as cur: + cur.execute("SELECT * FROM portfolios WHERE portfolio_id = %s", (portfolio_id,)) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}") + return self._row_to_portfolio(row) def list_portfolios(self) -> list[Portfolio]: """Return all portfolios ordered by created_at DESC.""" - # TODO: implement - raise NotImplementedError + with self._cursor() as cur: + cur.execute("SELECT * FROM portfolios ORDER BY created_at DESC") + rows = cur.fetchall() + return [self._row_to_portfolio(r) for r in rows] def update_portfolio(self, portfolio: Portfolio) -> Portfolio: - """Update mutable portfolio fields (cash, report_path, metadata). - - Args: - portfolio: Portfolio with updated field values. - - Returns: - Updated Portfolio with refreshed updated_at. - - Raises: - PortfolioNotFoundError: If portfolio_id does not exist. - """ - # TODO: implement - raise NotImplementedError + """Update mutable portfolio fields (cash, report_path, metadata).""" + with self._cursor() as cur: + cur.execute( + """UPDATE portfolios + SET cash = %s, report_path = %s, metadata = %s + WHERE portfolio_id = %s + RETURNING *""", + (portfolio.cash, portfolio.report_path, + json.dumps(portfolio.metadata), portfolio.portfolio_id), + ) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio.portfolio_id}") + return self._row_to_portfolio(row) def delete_portfolio(self, portfolio_id: str) -> None: - """Delete a portfolio and all associated data (CASCADE). - - Args: - portfolio_id: UUID of the portfolio to delete. - - Raises: - PortfolioNotFoundError: If portfolio_id does not exist. - """ - # TODO: implement - raise NotImplementedError + """Delete a portfolio and all associated data (CASCADE).""" + with self._cursor() as cur: + cur.execute( + "DELETE FROM portfolios WHERE portfolio_id = %s RETURNING portfolio_id", + (portfolio_id,), + ) + row = cur.fetchone() + if not row: + raise PortfolioNotFoundError(f"Portfolio not found: {portfolio_id}") # ------------------------------------------------------------------ # Holdings CRUD # ------------------------------------------------------------------ def upsert_holding(self, holding: Holding) -> Holding: - """Insert or update a holding row (upsert on portfolio_id + ticker). - - Args: - holding: Holding instance with all required fields set. - - Returns: - Holding with DB-assigned / refreshed timestamps. - """ - # TODO: implement - raise NotImplementedError + """Insert or update a holding row (upsert on portfolio_id + ticker).""" + hid = holding.holding_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO holdings + (holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique + DO UPDATE SET shares = EXCLUDED.shares, + avg_cost = EXCLUDED.avg_cost, + sector = EXCLUDED.sector, + industry = EXCLUDED.industry + RETURNING *""", + (hid, holding.portfolio_id, holding.ticker.upper(), + holding.shares, holding.avg_cost, holding.sector, holding.industry), + ) + row = cur.fetchone() + return self._row_to_holding(row) def get_holding(self, portfolio_id: str, ticker: str) -> Holding | None: - """Return the holding for (portfolio_id, ticker), or None if not found. - - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol (case-insensitive, stored as uppercase). - """ - # TODO: implement - raise NotImplementedError + """Return the holding for (portfolio_id, ticker), or None.""" + with self._cursor() as cur: + cur.execute( + "SELECT * FROM holdings WHERE portfolio_id = %s AND ticker = %s", + (portfolio_id, ticker.upper()), + ) + row = cur.fetchone() + return self._row_to_holding(row) if row else None def list_holdings(self, portfolio_id: str) -> list[Holding]: - """Return all holdings for a portfolio ordered by cost_basis DESC. - - Args: - portfolio_id: UUID of the target portfolio. - """ - # TODO: implement - raise NotImplementedError + """Return all holdings for a portfolio ordered by cost_basis DESC.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM holdings + WHERE portfolio_id = %s + ORDER BY shares * avg_cost DESC""", + (portfolio_id,), + ) + rows = cur.fetchall() + return [self._row_to_holding(r) for r in rows] def delete_holding(self, portfolio_id: str, ticker: str) -> None: - """Delete the holding for (portfolio_id, ticker). - - Args: - portfolio_id: UUID of the target portfolio. - ticker: Ticker symbol. - - Raises: - HoldingNotFoundError: If no such holding exists. - """ - # TODO: implement - raise NotImplementedError + """Delete the holding for (portfolio_id, ticker).""" + with self._cursor() as cur: + cur.execute( + "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s RETURNING holding_id", + (portfolio_id, ticker.upper()), + ) + row = cur.fetchone() + if not row: + raise HoldingNotFoundError( + f"Holding not found: {ticker} in portfolio {portfolio_id}" + ) # ------------------------------------------------------------------ # Trades # ------------------------------------------------------------------ def record_trade(self, trade: Trade) -> Trade: - """Insert a new trade record. Immutable — no update method. - - Args: - trade: Trade instance with all required fields set. - - Returns: - Trade with DB-assigned trade_id and trade_date. - """ - # TODO: implement - raise NotImplementedError + """Insert a new trade record.""" + tid = trade.trade_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO trades + (trade_id, portfolio_id, ticker, action, shares, price, + total_value, rationale, signal_source, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (tid, trade.portfolio_id, trade.ticker, trade.action, + trade.shares, trade.price, trade.total_value, + trade.rationale, trade.signal_source, + json.dumps(trade.metadata)), + ) + row = cur.fetchone() + return self._row_to_trade(row) def list_trades( self, @@ -208,51 +250,142 @@ class SupabaseClient: ticker: str | None = None, limit: int = 100, ) -> list[Trade]: - """Return recent trades for a portfolio, newest first. - - Args: - portfolio_id: Filter by portfolio. - ticker: Optional additional filter by ticker symbol. - limit: Maximum number of rows to return. - """ - # TODO: implement - raise NotImplementedError + """Return recent trades for a portfolio, newest first.""" + if ticker: + query = """SELECT * FROM trades + WHERE portfolio_id = %s AND ticker = %s + ORDER BY trade_date DESC LIMIT %s""" + params = (portfolio_id, ticker.upper(), limit) + else: + query = """SELECT * FROM trades + WHERE portfolio_id = %s + ORDER BY trade_date DESC LIMIT %s""" + params = (portfolio_id, limit) + with self._cursor() as cur: + cur.execute(query, params) + rows = cur.fetchall() + return [self._row_to_trade(r) for r in rows] # ------------------------------------------------------------------ # Snapshots # ------------------------------------------------------------------ def save_snapshot(self, snapshot: PortfolioSnapshot) -> PortfolioSnapshot: - """Insert a new immutable portfolio snapshot. - - Args: - snapshot: PortfolioSnapshot with all required fields set. - - Returns: - Snapshot with DB-assigned snapshot_id and snapshot_date. - """ - # TODO: implement - raise NotImplementedError + """Insert a new immutable portfolio snapshot.""" + sid = snapshot.snapshot_id or str(uuid.uuid4()) + with self._cursor() as cur: + cur.execute( + """INSERT INTO snapshots + (snapshot_id, portfolio_id, total_value, cash, equity_value, + num_positions, holdings_snapshot, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + RETURNING *""", + (sid, snapshot.portfolio_id, snapshot.total_value, + snapshot.cash, snapshot.equity_value, snapshot.num_positions, + json.dumps(snapshot.holdings_snapshot), + json.dumps(snapshot.metadata)), + ) + row = cur.fetchone() + return self._row_to_snapshot(row) def get_latest_snapshot(self, portfolio_id: str) -> PortfolioSnapshot | None: - """Return the most recent snapshot for a portfolio, or None. - - Args: - portfolio_id: UUID of the target portfolio. - """ - # TODO: implement - raise NotImplementedError + """Return the most recent snapshot for a portfolio, or None.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM snapshots + WHERE portfolio_id = %s + ORDER BY snapshot_date DESC LIMIT 1""", + (portfolio_id,), + ) + row = cur.fetchone() + return self._row_to_snapshot(row) if row else None def list_snapshots( self, portfolio_id: str, limit: int = 30, ) -> list[PortfolioSnapshot]: - """Return snapshots newest-first up to limit. + """Return snapshots newest-first up to limit.""" + with self._cursor() as cur: + cur.execute( + """SELECT * FROM snapshots + WHERE portfolio_id = %s + ORDER BY snapshot_date DESC LIMIT %s""", + (portfolio_id, limit), + ) + rows = cur.fetchall() + return [self._row_to_snapshot(r) for r in rows] - Args: - portfolio_id: UUID of the target portfolio. - limit: Maximum number of snapshots to return. - """ - # TODO: implement - raise NotImplementedError + # ------------------------------------------------------------------ + # Row -> Model helpers + # ------------------------------------------------------------------ + + @staticmethod + def _row_to_portfolio(row: dict) -> Portfolio: + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return Portfolio( + portfolio_id=str(row["portfolio_id"]), + name=row["name"], + cash=float(row["cash"]), + initial_cash=float(row["initial_cash"]), + currency=row["currency"].strip(), + created_at=str(row["created_at"]), + updated_at=str(row["updated_at"]), + report_path=row.get("report_path"), + metadata=metadata, + ) + + @staticmethod + def _row_to_holding(row: dict) -> Holding: + return Holding( + holding_id=str(row["holding_id"]), + portfolio_id=str(row["portfolio_id"]), + ticker=row["ticker"], + shares=float(row["shares"]), + avg_cost=float(row["avg_cost"]), + sector=row.get("sector"), + industry=row.get("industry"), + created_at=str(row["created_at"]), + updated_at=str(row["updated_at"]), + ) + + @staticmethod + def _row_to_trade(row: dict) -> Trade: + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return Trade( + trade_id=str(row["trade_id"]), + portfolio_id=str(row["portfolio_id"]), + ticker=row["ticker"], + action=row["action"], + shares=float(row["shares"]), + price=float(row["price"]), + total_value=float(row["total_value"]), + trade_date=str(row["trade_date"]), + rationale=row.get("rationale"), + signal_source=row.get("signal_source"), + metadata=metadata, + ) + + @staticmethod + def _row_to_snapshot(row: dict) -> PortfolioSnapshot: + holdings = row.get("holdings_snapshot") or [] + if isinstance(holdings, str): + holdings = json.loads(holdings) + metadata = row.get("metadata") or {} + if isinstance(metadata, str): + metadata = json.loads(metadata) + return PortfolioSnapshot( + snapshot_id=str(row["snapshot_id"]), + portfolio_id=str(row["portfolio_id"]), + snapshot_date=str(row["snapshot_date"]), + total_value=float(row["total_value"]), + cash=float(row["cash"]), + equity_value=float(row["equity_value"]), + num_positions=int(row["num_positions"]), + holdings_snapshot=holdings, + metadata=metadata, + )