From aa4dcdeb806684d312392b0a0d66b0424edc9c65 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Mar 2026 11:16:39 +0000 Subject: [PATCH] feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com> --- tests/portfolio/conftest.py | 85 ++++++-- tests/portfolio/test_models.py | 183 ++++++++++++++---- tests/portfolio/test_report_store.py | 115 +++++++---- .../migrations/001_initial_schema.sql | 6 +- tradingagents/portfolio/models.py | 147 ++++++++++++-- tradingagents/portfolio/report_store.py | 76 +++++--- 6 files changed, 464 insertions(+), 148 deletions(-) diff --git a/tests/portfolio/conftest.py b/tests/portfolio/conftest.py index bbbf0d78..073d318d 100644 --- a/tests/portfolio/conftest.py +++ b/tests/portfolio/conftest.py @@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when from __future__ import annotations import os -import uuid from pathlib import Path from unittest.mock import MagicMock import pytest +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) +from tradingagents.portfolio.report_store import ReportStore +from tradingagents.portfolio.supabase_client import SupabaseClient + # --------------------------------------------------------------------------- # Skip marker for Supabase integration tests # --------------------------------------------------------------------------- @@ -51,31 +59,72 @@ def sample_holding_id() -> str: @pytest.fixture -def sample_portfolio(sample_portfolio_id: str): +def sample_portfolio(sample_portfolio_id: str) -> Portfolio: """Return an unsaved Portfolio instance for testing.""" - # TODO: implement — construct a Portfolio dataclass with test values - raise NotImplementedError + return Portfolio( + portfolio_id=sample_portfolio_id, + name="Test Portfolio", + cash=50_000.0, + initial_cash=100_000.0, + currency="USD", + created_at="2026-03-20T00:00:00Z", + updated_at="2026-03-20T00:00:00Z", + report_path="reports/daily/2026-03-20/portfolio", + metadata={"strategy": "test"}, + ) @pytest.fixture -def sample_holding(sample_portfolio_id: str, sample_holding_id: str): +def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding: """Return an unsaved Holding instance for testing.""" - # TODO: implement — construct a Holding dataclass with test values - raise NotImplementedError + return Holding( + holding_id=sample_holding_id, + portfolio_id=sample_portfolio_id, + ticker="AAPL", + shares=100.0, + avg_cost=150.0, + sector="Technology", + industry="Consumer Electronics", + created_at="2026-03-20T00:00:00Z", + updated_at="2026-03-20T00:00:00Z", + ) @pytest.fixture -def sample_trade(sample_portfolio_id: str): +def sample_trade(sample_portfolio_id: str) -> Trade: """Return an unsaved Trade instance for testing.""" - # TODO: implement — construct a Trade dataclass with test values - raise NotImplementedError + return Trade( + trade_id="33333333-3333-3333-3333-333333333333", + portfolio_id=sample_portfolio_id, + ticker="AAPL", + action="BUY", + shares=100.0, + price=150.0, + total_value=15_000.0, + trade_date="2026-03-20T10:00:00Z", + rationale="Strong momentum signal", + signal_source="scanner", + metadata={"confidence": 0.85}, + ) @pytest.fixture -def sample_snapshot(sample_portfolio_id: str): +def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot: """Return an unsaved PortfolioSnapshot instance for testing.""" - # TODO: implement — construct a PortfolioSnapshot dataclass with test values - raise NotImplementedError + return PortfolioSnapshot( + snapshot_id="44444444-4444-4444-4444-444444444444", + portfolio_id=sample_portfolio_id, + snapshot_date="2026-03-20", + total_value=115_000.0, + cash=50_000.0, + equity_value=65_000.0, + num_positions=2, + holdings_snapshot=[ + {"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0}, + {"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0}, + ], + metadata={"note": "end of day snapshot"}, + ) # --------------------------------------------------------------------------- @@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path: @pytest.fixture -def report_store(tmp_reports: Path): +def report_store(tmp_reports: Path) -> ReportStore: """ReportStore instance backed by a temporary directory.""" - # TODO: implement — return ReportStore(base_dir=tmp_reports) - raise NotImplementedError + return ReportStore(base_dir=tmp_reports) # --------------------------------------------------------------------------- @@ -104,7 +152,6 @@ def report_store(tmp_reports: Path): @pytest.fixture -def mock_supabase_client(): +def mock_supabase_client() -> MagicMock: """MagicMock of SupabaseClient for unit tests that don't hit the DB.""" - # TODO: implement — return MagicMock(spec=SupabaseClient) - raise NotImplementedError + return MagicMock(spec=SupabaseClient) diff --git a/tests/portfolio/test_models.py b/tests/portfolio/test_models.py index 72f3c798..bcd2258a 100644 --- a/tests/portfolio/test_models.py +++ b/tests/portfolio/test_models.py @@ -16,6 +16,13 @@ from __future__ import annotations import pytest +from tradingagents.portfolio.models import ( + Holding, + Portfolio, + PortfolioSnapshot, + Trade, +) + # --------------------------------------------------------------------------- # Portfolio round-trip @@ -24,19 +31,41 @@ import pytest def test_portfolio_to_dict_round_trip(sample_portfolio): """Portfolio.to_dict() -> Portfolio.from_dict() must be lossless.""" - # TODO: implement - # d = sample_portfolio.to_dict() - # restored = Portfolio.from_dict(d) - # assert restored.portfolio_id == sample_portfolio.portfolio_id - # assert restored.cash == sample_portfolio.cash - # ... all stored fields - raise NotImplementedError + d = sample_portfolio.to_dict() + restored = Portfolio.from_dict(d) + assert restored.portfolio_id == sample_portfolio.portfolio_id + assert restored.name == sample_portfolio.name + assert restored.cash == sample_portfolio.cash + assert restored.initial_cash == sample_portfolio.initial_cash + assert restored.currency == sample_portfolio.currency + assert restored.created_at == sample_portfolio.created_at + assert restored.updated_at == sample_portfolio.updated_at + assert restored.report_path == sample_portfolio.report_path + assert restored.metadata == sample_portfolio.metadata def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): """to_dict() must not include computed fields (total_value, equity_value, cash_pct).""" - # TODO: implement - raise NotImplementedError + d = sample_portfolio.to_dict() + assert "total_value" not in d + assert "equity_value" not in d + assert "cash_pct" not in d + + +def test_portfolio_from_dict_defaults_optional_fields(): + """from_dict() must tolerate missing optional fields.""" + minimal = { + "portfolio_id": "pid-1", + "name": "Minimal", + "cash": 1000.0, + "initial_cash": 1000.0, + } + p = Portfolio.from_dict(minimal) + assert p.currency == "USD" + assert p.created_at == "" + assert p.updated_at == "" + assert p.report_path is None + assert p.metadata == {} # --------------------------------------------------------------------------- @@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): def test_holding_to_dict_round_trip(sample_holding): """Holding.to_dict() -> Holding.from_dict() must be lossless.""" - # TODO: implement - raise NotImplementedError + d = sample_holding.to_dict() + restored = Holding.from_dict(d) + assert restored.holding_id == sample_holding.holding_id + assert restored.portfolio_id == sample_holding.portfolio_id + assert restored.ticker == sample_holding.ticker + assert restored.shares == sample_holding.shares + assert restored.avg_cost == sample_holding.avg_cost + assert restored.sector == sample_holding.sector + assert restored.industry == sample_holding.industry def test_holding_to_dict_excludes_runtime_fields(sample_holding): """to_dict() must not include current_price, current_value, weight, etc.""" - # TODO: implement - raise NotImplementedError + d = sample_holding.to_dict() + for field in ("current_price", "current_value", "cost_basis", + "unrealized_pnl", "unrealized_pnl_pct", "weight"): + assert field not in d # --------------------------------------------------------------------------- @@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding): def test_trade_to_dict_round_trip(sample_trade): """Trade.to_dict() -> Trade.from_dict() must be lossless.""" - # TODO: implement - raise NotImplementedError + d = sample_trade.to_dict() + restored = Trade.from_dict(d) + assert restored.trade_id == sample_trade.trade_id + assert restored.portfolio_id == sample_trade.portfolio_id + assert restored.ticker == sample_trade.ticker + assert restored.action == sample_trade.action + assert restored.shares == sample_trade.shares + assert restored.price == sample_trade.price + assert restored.total_value == sample_trade.total_value + assert restored.trade_date == sample_trade.trade_date + assert restored.rationale == sample_trade.rationale + assert restored.signal_source == sample_trade.signal_source + assert restored.metadata == sample_trade.metadata # --------------------------------------------------------------------------- @@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade): def test_snapshot_to_dict_round_trip(sample_snapshot): """PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip.""" - # TODO: implement - raise NotImplementedError + d = sample_snapshot.to_dict() + restored = PortfolioSnapshot.from_dict(d) + assert restored.snapshot_id == sample_snapshot.snapshot_id + assert restored.portfolio_id == sample_snapshot.portfolio_id + assert restored.snapshot_date == sample_snapshot.snapshot_date + assert restored.total_value == sample_snapshot.total_value + assert restored.cash == sample_snapshot.cash + assert restored.equity_value == sample_snapshot.equity_value + assert restored.num_positions == sample_snapshot.num_positions + assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot + assert restored.metadata == sample_snapshot.metadata + + +def test_snapshot_from_dict_parses_holdings_snapshot_json_string(): + """from_dict() must parse holdings_snapshot when it arrives as a JSON string.""" + import json + holdings = [{"ticker": "AAPL", "shares": 10.0}] + data = { + "snapshot_id": "snap-1", + "portfolio_id": "pid-1", + "snapshot_date": "2026-03-20", + "total_value": 110_000.0, + "cash": 10_000.0, + "equity_value": 100_000.0, + "num_positions": 1, + "holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase + } + snap = PortfolioSnapshot.from_dict(data) + assert snap.holdings_snapshot == holdings # --------------------------------------------------------------------------- @@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot): def test_holding_enrich_computes_current_value(sample_holding): """enrich() must set current_value = current_price * shares.""" - # TODO: implement - # sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) - # assert sample_holding.current_value == 200.0 * sample_holding.shares - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert sample_holding.current_value == 200.0 * sample_holding.shares def test_holding_enrich_computes_unrealized_pnl(sample_holding): """enrich() must set unrealized_pnl = current_value - cost_basis.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + expected_cost_basis = sample_holding.avg_cost * sample_holding.shares + expected_pnl = sample_holding.current_value - expected_cost_basis + assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl) + + +def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding): + """enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis.""" + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + cost_basis = sample_holding.avg_cost * sample_holding.shares + expected_pct = sample_holding.unrealized_pnl / cost_basis + assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct) def test_holding_enrich_computes_weight(sample_holding): """enrich() must set weight = current_value / portfolio_total_value.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + expected_weight = sample_holding.current_value / 100_000.0 + assert sample_holding.weight == pytest.approx(expected_weight) + + +def test_holding_enrich_returns_self(sample_holding): + """enrich() must return self for chaining.""" + result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert result is sample_holding def test_holding_enrich_handles_zero_cost(sample_holding): """When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError).""" - # TODO: implement - raise NotImplementedError + sample_holding.avg_cost = 0.0 + sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) + assert sample_holding.unrealized_pnl_pct == 0.0 def test_holding_enrich_handles_zero_portfolio_value(sample_holding): """When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError).""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0) + assert sample_holding.weight == 0.0 # --------------------------------------------------------------------------- @@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding): def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding): """Portfolio.enrich() must compute total_value = cash + sum(holding.current_value).""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + expected_equity = 200.0 * sample_holding.shares + assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity) + + +def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding): + """Portfolio.enrich() must set equity_value = sum(holding.current_value).""" + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares) def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding): """Portfolio.enrich() must compute cash_pct = cash / total_value.""" - # TODO: implement - raise NotImplementedError + sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich() + sample_portfolio.enrich([sample_holding]) + expected_pct = sample_portfolio.cash / sample_portfolio.total_value + assert sample_portfolio.cash_pct == pytest.approx(expected_pct) + + +def test_portfolio_enrich_returns_self(sample_portfolio): + """enrich() must return self for chaining.""" + result = sample_portfolio.enrich([]) + assert result is sample_portfolio + + +def test_portfolio_enrich_no_holdings(sample_portfolio): + """Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash.""" + sample_portfolio.enrich([]) + assert sample_portfolio.equity_value == 0.0 + assert sample_portfolio.total_value == sample_portfolio.cash + assert sample_portfolio.cash_pct == 1.0 diff --git a/tests/portfolio/test_report_store.py b/tests/portfolio/test_report_store.py index ba111799..9f042a26 100644 --- a/tests/portfolio/test_report_store.py +++ b/tests/portfolio/test_report_store.py @@ -12,10 +12,14 @@ Run:: from __future__ import annotations +import json from pathlib import Path import pytest +from tradingagents.portfolio.exceptions import ReportStoreError +from tradingagents.portfolio.report_store import ReportStore + # --------------------------------------------------------------------------- # Macro scan @@ -24,19 +28,17 @@ import pytest def test_save_and_load_scan(report_store, tmp_reports): """save_scan() then load_scan() must return the original data.""" - # TODO: implement - # data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} - # path = report_store.save_scan("2026-03-20", data) - # assert path.exists() - # loaded = report_store.load_scan("2026-03-20") - # assert loaded == data - raise NotImplementedError + data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} + path = report_store.save_scan("2026-03-20", data) + assert path.exists() + loaded = report_store.load_scan("2026-03-20") + assert loaded == data def test_load_scan_returns_none_for_missing_file(report_store): """load_scan() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_scan("1900-01-01") + assert result is None # --------------------------------------------------------------------------- @@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store): def test_save_and_load_analysis(report_store): """save_analysis() then load_analysis() must return the original data.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92} + report_store.save_analysis("2026-03-20", "AAPL", data) + loaded = report_store.load_analysis("2026-03-20", "AAPL") + assert loaded == data def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): """Ticker symbol must be stored as uppercase in the directory path.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "aapl"} + report_store.save_analysis("2026-03-20", "aapl", data) + expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json" + assert expected.exists() + # load with lowercase should still work + loaded = report_store.load_analysis("2026-03-20", "aapl") + assert loaded == data # --------------------------------------------------------------------------- @@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): def test_save_and_load_holding_review(report_store): """save_holding_review() then load_holding_review() must round-trip.""" - # TODO: implement - raise NotImplementedError + data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0} + report_store.save_holding_review("2026-03-20", "MSFT", data) + loaded = report_store.load_holding_review("2026-03-20", "MSFT") + assert loaded == data def test_load_holding_review_returns_none_for_missing(report_store): """load_holding_review() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_holding_review("1900-01-01", "ZZZZ") + assert result is None # --------------------------------------------------------------------------- @@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store): def test_save_and_load_risk_metrics(report_store): """save_risk_metrics() then load_risk_metrics() must round-trip.""" - # TODO: implement - raise NotImplementedError + data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12} + report_store.save_risk_metrics("2026-03-20", "pid-123", data) + loaded = report_store.load_risk_metrics("2026-03-20", "pid-123") + assert loaded == data # --------------------------------------------------------------------------- @@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store): def test_save_and_load_pm_decision_json(report_store): """save_pm_decision() then load_pm_decision() must round-trip JSON.""" - # TODO: implement - # decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} - # report_store.save_pm_decision("2026-03-20", "pid-123", decision) - # loaded = report_store.load_pm_decision("2026-03-20", "pid-123") - # assert loaded == decision - raise NotImplementedError + decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} + report_store.save_pm_decision("2026-03-20", "pid-123", decision) + loaded = report_store.load_pm_decision("2026-03-20", "pid-123") + assert loaded == decision def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports): """When markdown is passed to save_pm_decision(), .md file must be written.""" - # TODO: implement - raise NotImplementedError + decision = {"sells": [], "buys": []} + md_text = "# Decision\n\nHold everything." + report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text) + md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md" + assert md_path.exists() + assert md_path.read_text(encoding="utf-8") == md_text def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports): """When markdown=None, no .md file should be written.""" - # TODO: implement - raise NotImplementedError + decision = {"sells": [], "buys": []} + report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None) + md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md" + assert not md_path.exists() def test_load_pm_decision_returns_none_for_missing(report_store): """load_pm_decision() must return None when the file does not exist.""" - # TODO: implement - raise NotImplementedError + result = report_store.load_pm_decision("1900-01-01", "pid-none") + assert result is None def test_list_pm_decisions(report_store): """list_pm_decisions() must return all saved decision paths, newest first.""" - # TODO: implement - # Save decisions for multiple dates, verify order - raise NotImplementedError + dates = ["2026-03-18", "2026-03-19", "2026-03-20"] + for d in dates: + report_store.save_pm_decision(d, "pid-abc", {"date": d}) + paths = report_store.list_pm_decisions("pid-abc") + assert len(paths) == 3 + # Sorted newest first by ISO date string ordering + date_parts = [p.parent.parent.name for p in paths] + assert date_parts == sorted(dates, reverse=True) # --------------------------------------------------------------------------- @@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store): def test_directories_created_on_write(report_store, tmp_reports): """Directories must be created automatically on first write.""" - # TODO: implement - # assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists() - # report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) - # assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir() - raise NotImplementedError + target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio" + assert not target_dir.exists() + report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) + assert target_dir.is_dir() def test_json_formatted_with_indent(report_store, tmp_reports): """Written JSON files must use indent=2 for human readability.""" - # TODO: implement - # Write a file, read the raw bytes, verify indentation - raise NotImplementedError + data = {"key": "value", "nested": {"a": 1}} + path = report_store.save_scan("2026-03-20", data) + raw = path.read_text(encoding="utf-8") + # indent=2 means lines like ' "key": ...' + assert ' "key"' in raw + + +def test_read_json_raises_on_corrupt_file(report_store, tmp_reports): + """_read_json must raise ReportStoreError for corrupt JSON.""" + corrupt = tmp_reports / "corrupt.json" + corrupt.write_text("not valid json{{{", encoding="utf-8") + with pytest.raises(ReportStoreError): + report_store._read_json(corrupt) diff --git a/tradingagents/portfolio/migrations/001_initial_schema.sql b/tradingagents/portfolio/migrations/001_initial_schema.sql index 42724b7a..6260d9d7 100644 --- a/tradingagents/portfolio/migrations/001_initial_schema.sql +++ b/tradingagents/portfolio/migrations/001_initial_schema.sql @@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades ( trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, ticker TEXT NOT NULL, - action TEXT NOT NULL, + action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')), shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), price NUMERIC(18,4) NOT NULL CHECK (price > 0), total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0), trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), rationale TEXT, -- PM agent rationale for this trade signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent' - metadata JSONB NOT NULL DEFAULT '{}', - - CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL')) + metadata JSONB NOT NULL DEFAULT '{}' ); COMMENT ON TABLE trades IS diff --git a/tradingagents/portfolio/models.py b/tradingagents/portfolio/models.py index 9a61fe9b..0cbfd442 100644 --- a/tradingagents/portfolio/models.py +++ b/tradingagents/portfolio/models.py @@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with: - ``from_dict()`` class method for deserialisation - ``enrich()`` for attaching runtime-computed fields +**float vs Decimal** — monetary fields (cash, price, shares, etc.) use plain +``float`` throughout. Rationale: + +1. This is **mock trading only** — no real money changes hands. The cost of a + subtle floating-point rounding error is zero. +2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float`` + already. Converting to ``Decimal`` at the boundary would require a custom + JSON encoder *and* decoder everywhere, for no practical gain. +3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises + ``TypeError`` without a custom encoder. +4. If this ever becomes real-money trading, replace ``float`` with + ``decimal.Decimal`` and add a ``DecimalEncoder`` — the interface is + identical and the change is localised to this file. + See ``docs/portfolio/02_data_models.md`` for full field specifications. """ from __future__ import annotations +import json from dataclasses import dataclass, field from typing import Any @@ -48,8 +63,17 @@ class Portfolio: Runtime-computed fields are excluded. """ - # TODO: implement - raise NotImplementedError + return { + "portfolio_id": self.portfolio_id, + "name": self.name, + "cash": self.cash, + "initial_cash": self.initial_cash, + "currency": self.currency, + "created_at": self.created_at, + "updated_at": self.updated_at, + "report_path": self.report_path, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Portfolio": @@ -57,8 +81,17 @@ class Portfolio: Missing optional fields default gracefully. Extra keys are ignored. """ - # TODO: implement - raise NotImplementedError + return cls( + portfolio_id=data["portfolio_id"], + name=data["name"], + cash=float(data["cash"]), + initial_cash=float(data["initial_cash"]), + currency=data.get("currency", "USD"), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + report_path=data.get("report_path"), + metadata=data.get("metadata") or {}, + ) def enrich(self, holdings: list["Holding"]) -> "Portfolio": """Compute total_value, equity_value, cash_pct from holdings. @@ -69,8 +102,12 @@ class Portfolio: holdings: List of Holding objects with current_value populated (i.e., ``holding.enrich()`` already called). """ - # TODO: implement - raise NotImplementedError + self.equity_value = sum( + h.current_value for h in holdings if h.current_value is not None + ) + self.total_value = self.cash + self.equity_value + self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0 + return self # --------------------------------------------------------------------------- @@ -106,14 +143,32 @@ class Holding: def to_dict(self) -> dict[str, Any]: """Serialise stored fields only (runtime-computed fields excluded).""" - # TODO: implement - raise NotImplementedError + return { + "holding_id": self.holding_id, + "portfolio_id": self.portfolio_id, + "ticker": self.ticker, + "shares": self.shares, + "avg_cost": self.avg_cost, + "sector": self.sector, + "industry": self.industry, + "created_at": self.created_at, + "updated_at": self.updated_at, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Holding": """Deserialise from a DB row or JSON dict.""" - # TODO: implement - raise NotImplementedError + return cls( + holding_id=data["holding_id"], + portfolio_id=data["portfolio_id"], + ticker=data["ticker"], + shares=float(data["shares"]), + avg_cost=float(data["avg_cost"]), + sector=data.get("sector"), + industry=data.get("industry"), + created_at=data.get("created_at", ""), + updated_at=data.get("updated_at", ""), + ) def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding": """Populate runtime-computed fields in-place and return self. @@ -129,8 +184,17 @@ class Holding: current_price: Latest market price for this ticker. portfolio_total_value: Total portfolio value (cash + equity). """ - # TODO: implement - raise NotImplementedError + self.current_price = current_price + self.current_value = current_price * self.shares + self.cost_basis = self.avg_cost * self.shares + self.unrealized_pnl = self.current_value - self.cost_basis + self.unrealized_pnl_pct = ( + self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0 + ) + self.weight = ( + self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0 + ) + return self # --------------------------------------------------------------------------- @@ -159,14 +223,36 @@ class Trade: def to_dict(self) -> dict[str, Any]: """Serialise all fields.""" - # TODO: implement - raise NotImplementedError + return { + "trade_id": self.trade_id, + "portfolio_id": self.portfolio_id, + "ticker": self.ticker, + "action": self.action, + "shares": self.shares, + "price": self.price, + "total_value": self.total_value, + "trade_date": self.trade_date, + "rationale": self.rationale, + "signal_source": self.signal_source, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "Trade": """Deserialise from a DB row or JSON dict.""" - # TODO: implement - raise NotImplementedError + return cls( + trade_id=data["trade_id"], + portfolio_id=data["portfolio_id"], + ticker=data["ticker"], + action=data["action"], + shares=float(data["shares"]), + price=float(data["price"]), + total_value=float(data["total_value"]), + trade_date=data.get("trade_date", ""), + rationale=data.get("rationale"), + signal_source=data.get("signal_source"), + metadata=data.get("metadata") or {}, + ) # --------------------------------------------------------------------------- @@ -194,8 +280,17 @@ class PortfolioSnapshot: def to_dict(self) -> dict[str, Any]: """Serialise all fields. ``holdings_snapshot`` is already a list[dict].""" - # TODO: implement - raise NotImplementedError + return { + "snapshot_id": self.snapshot_id, + "portfolio_id": self.portfolio_id, + "snapshot_date": self.snapshot_date, + "total_value": self.total_value, + "cash": self.cash, + "equity_value": self.equity_value, + "num_positions": self.num_positions, + "holdings_snapshot": self.holdings_snapshot, + "metadata": self.metadata, + } @classmethod def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot": @@ -203,5 +298,17 @@ class PortfolioSnapshot: ``holdings_snapshot`` is parsed from a JSON string when needed. """ - # TODO: implement - raise NotImplementedError + holdings_snapshot = data.get("holdings_snapshot", []) + if isinstance(holdings_snapshot, str): + holdings_snapshot = json.loads(holdings_snapshot) + return cls( + snapshot_id=data["snapshot_id"], + portfolio_id=data["portfolio_id"], + snapshot_date=data["snapshot_date"], + total_value=float(data["total_value"]), + cash=float(data["cash"]), + equity_value=float(data["equity_value"]), + num_positions=int(data["num_positions"]), + holdings_snapshot=holdings_snapshot, + metadata=data.get("metadata") or {}, + ) diff --git a/tradingagents/portfolio/report_store.py b/tradingagents/portfolio/report_store.py index 5ed8594a..2d641693 100644 --- a/tradingagents/portfolio/report_store.py +++ b/tradingagents/portfolio/report_store.py @@ -28,9 +28,12 @@ Usage:: from __future__ import annotations +import json from pathlib import Path from typing import Any +from tradingagents.portfolio.exceptions import ReportStoreError + class ReportStore: """Filesystem document store for all portfolio-related reports. @@ -48,8 +51,7 @@ class ReportStore: Override via the ``PORTFOLIO_DATA_DIR`` env var or ``get_portfolio_config()["data_dir"]``. """ - # TODO: implement — store Path(base_dir), resolve as needed - raise NotImplementedError + self._base_dir = Path(base_dir) # ------------------------------------------------------------------ # Internal helpers @@ -60,8 +62,7 @@ class ReportStore: Path: ``{base_dir}/daily/{date}/portfolio/`` """ - # TODO: implement - raise NotImplementedError + return self._base_dir / "daily" / date / "portfolio" def _write_json(self, path: Path, data: dict[str, Any]) -> Path: """Write a dict to a JSON file, creating parent directories as needed. @@ -76,8 +77,12 @@ class ReportStore: Raises: ReportStoreError: On filesystem write failure. """ - # TODO: implement - raise NotImplementedError + try: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data, indent=2), encoding="utf-8") + return path + except OSError as exc: + raise ReportStoreError(f"Failed to write {path}: {exc}") from exc def _read_json(self, path: Path) -> dict[str, Any] | None: """Read a JSON file, returning None if the file does not exist. @@ -85,8 +90,12 @@ class ReportStore: Raises: ReportStoreError: On JSON parse error (file exists but is corrupt). """ - # TODO: implement - raise NotImplementedError + if not path.exists(): + return None + try: + return json.loads(path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc # ------------------------------------------------------------------ # Macro Scan @@ -104,13 +113,13 @@ class ReportStore: Returns: Path of the written file. """ - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json" + return self._write_json(path, data) def load_scan(self, date: str) -> dict[str, Any] | None: """Load macro scan summary. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json" + return self._read_json(path) # ------------------------------------------------------------------ # Per-Ticker Analysis @@ -126,13 +135,13 @@ class ReportStore: ticker: Ticker symbol (stored as uppercase). data: Analysis output dict. """ - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json" + return self._write_json(path, data) def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None: """Load per-ticker analysis JSON. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json" + return self._read_json(path) # ------------------------------------------------------------------ # Holding Reviews @@ -153,13 +162,13 @@ class ReportStore: ticker: Ticker symbol (stored as uppercase). data: HoldingReviewerAgent output dict. """ - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json" + return self._write_json(path, data) def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None: """Load holding review output. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json" + return self._read_json(path) # ------------------------------------------------------------------ # Risk Metrics @@ -180,8 +189,8 @@ class ReportStore: portfolio_id: UUID of the target portfolio. data: Risk metrics dict (Sharpe, Sortino, VaR, etc.). """ - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json" + return self._write_json(path, data) def load_risk_metrics( self, @@ -189,8 +198,8 @@ class ReportStore: portfolio_id: str, ) -> dict[str, Any] | None: """Load risk metrics. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json" + return self._read_json(path) # ------------------------------------------------------------------ # PM Decisions @@ -218,8 +227,15 @@ class ReportStore: Returns: Path of the written JSON file. """ - # TODO: implement - raise NotImplementedError + json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json" + self._write_json(json_path, data) + if markdown is not None: + md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md" + try: + md_path.write_text(markdown, encoding="utf-8") + except OSError as exc: + raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc + return json_path def load_pm_decision( self, @@ -227,8 +243,8 @@ class ReportStore: portfolio_id: str, ) -> dict[str, Any] | None: """Load PM decision JSON. Returns None if the file does not exist.""" - # TODO: implement - raise NotImplementedError + path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json" + return self._read_json(path) def list_pm_decisions(self, portfolio_id: str) -> list[Path]: """Return all saved PM decision JSON paths for portfolio_id, newest first. @@ -241,5 +257,5 @@ class ReportStore: Returns: Sorted list of Path objects, newest date first. """ - # TODO: implement - raise NotImplementedError + pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json" + return sorted(self._base_dir.glob(pattern), reverse=True)