feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-03-20 11:16:39 +00:00
parent 7ea9866d1d
commit aa4dcdeb80
6 changed files with 464 additions and 148 deletions

View File

@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
from __future__ import annotations
import os
import uuid
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from tradingagents.portfolio.models import (
Holding,
Portfolio,
PortfolioSnapshot,
Trade,
)
from tradingagents.portfolio.report_store import ReportStore
from tradingagents.portfolio.supabase_client import SupabaseClient
# ---------------------------------------------------------------------------
# Skip marker for Supabase integration tests
# ---------------------------------------------------------------------------
@ -51,31 +59,72 @@ def sample_holding_id() -> str:
@pytest.fixture
def sample_portfolio(sample_portfolio_id: str):
def sample_portfolio(sample_portfolio_id: str) -> Portfolio:
"""Return an unsaved Portfolio instance for testing."""
# TODO: implement — construct a Portfolio dataclass with test values
raise NotImplementedError
return Portfolio(
portfolio_id=sample_portfolio_id,
name="Test Portfolio",
cash=50_000.0,
initial_cash=100_000.0,
currency="USD",
created_at="2026-03-20T00:00:00Z",
updated_at="2026-03-20T00:00:00Z",
report_path="reports/daily/2026-03-20/portfolio",
metadata={"strategy": "test"},
)
@pytest.fixture
def sample_holding(sample_portfolio_id: str, sample_holding_id: str):
def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding:
"""Return an unsaved Holding instance for testing."""
# TODO: implement — construct a Holding dataclass with test values
raise NotImplementedError
return Holding(
holding_id=sample_holding_id,
portfolio_id=sample_portfolio_id,
ticker="AAPL",
shares=100.0,
avg_cost=150.0,
sector="Technology",
industry="Consumer Electronics",
created_at="2026-03-20T00:00:00Z",
updated_at="2026-03-20T00:00:00Z",
)
@pytest.fixture
def sample_trade(sample_portfolio_id: str):
def sample_trade(sample_portfolio_id: str) -> Trade:
"""Return an unsaved Trade instance for testing."""
# TODO: implement — construct a Trade dataclass with test values
raise NotImplementedError
return Trade(
trade_id="33333333-3333-3333-3333-333333333333",
portfolio_id=sample_portfolio_id,
ticker="AAPL",
action="BUY",
shares=100.0,
price=150.0,
total_value=15_000.0,
trade_date="2026-03-20T10:00:00Z",
rationale="Strong momentum signal",
signal_source="scanner",
metadata={"confidence": 0.85},
)
@pytest.fixture
def sample_snapshot(sample_portfolio_id: str):
def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot:
"""Return an unsaved PortfolioSnapshot instance for testing."""
# TODO: implement — construct a PortfolioSnapshot dataclass with test values
raise NotImplementedError
return PortfolioSnapshot(
snapshot_id="44444444-4444-4444-4444-444444444444",
portfolio_id=sample_portfolio_id,
snapshot_date="2026-03-20",
total_value=115_000.0,
cash=50_000.0,
equity_value=65_000.0,
num_positions=2,
holdings_snapshot=[
{"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0},
{"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0},
],
metadata={"note": "end of day snapshot"},
)
# ---------------------------------------------------------------------------
@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path:
@pytest.fixture
def report_store(tmp_reports: Path):
def report_store(tmp_reports: Path) -> ReportStore:
"""ReportStore instance backed by a temporary directory."""
# TODO: implement — return ReportStore(base_dir=tmp_reports)
raise NotImplementedError
return ReportStore(base_dir=tmp_reports)
# ---------------------------------------------------------------------------
@ -104,7 +152,6 @@ def report_store(tmp_reports: Path):
@pytest.fixture
def mock_supabase_client():
def mock_supabase_client() -> MagicMock:
"""MagicMock of SupabaseClient for unit tests that don't hit the DB."""
# TODO: implement — return MagicMock(spec=SupabaseClient)
raise NotImplementedError
return MagicMock(spec=SupabaseClient)

View File

@ -16,6 +16,13 @@ from __future__ import annotations
import pytest
from tradingagents.portfolio.models import (
Holding,
Portfolio,
PortfolioSnapshot,
Trade,
)
# ---------------------------------------------------------------------------
# Portfolio round-trip
@ -24,19 +31,41 @@ import pytest
def test_portfolio_to_dict_round_trip(sample_portfolio):
"""Portfolio.to_dict() -> Portfolio.from_dict() must be lossless."""
# TODO: implement
# d = sample_portfolio.to_dict()
# restored = Portfolio.from_dict(d)
# assert restored.portfolio_id == sample_portfolio.portfolio_id
# assert restored.cash == sample_portfolio.cash
# ... all stored fields
raise NotImplementedError
d = sample_portfolio.to_dict()
restored = Portfolio.from_dict(d)
assert restored.portfolio_id == sample_portfolio.portfolio_id
assert restored.name == sample_portfolio.name
assert restored.cash == sample_portfolio.cash
assert restored.initial_cash == sample_portfolio.initial_cash
assert restored.currency == sample_portfolio.currency
assert restored.created_at == sample_portfolio.created_at
assert restored.updated_at == sample_portfolio.updated_at
assert restored.report_path == sample_portfolio.report_path
assert restored.metadata == sample_portfolio.metadata
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
"""to_dict() must not include computed fields (total_value, equity_value, cash_pct)."""
# TODO: implement
raise NotImplementedError
d = sample_portfolio.to_dict()
assert "total_value" not in d
assert "equity_value" not in d
assert "cash_pct" not in d
def test_portfolio_from_dict_defaults_optional_fields():
"""from_dict() must tolerate missing optional fields."""
minimal = {
"portfolio_id": "pid-1",
"name": "Minimal",
"cash": 1000.0,
"initial_cash": 1000.0,
}
p = Portfolio.from_dict(minimal)
assert p.currency == "USD"
assert p.created_at == ""
assert p.updated_at == ""
assert p.report_path is None
assert p.metadata == {}
# ---------------------------------------------------------------------------
@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
def test_holding_to_dict_round_trip(sample_holding):
"""Holding.to_dict() -> Holding.from_dict() must be lossless."""
# TODO: implement
raise NotImplementedError
d = sample_holding.to_dict()
restored = Holding.from_dict(d)
assert restored.holding_id == sample_holding.holding_id
assert restored.portfolio_id == sample_holding.portfolio_id
assert restored.ticker == sample_holding.ticker
assert restored.shares == sample_holding.shares
assert restored.avg_cost == sample_holding.avg_cost
assert restored.sector == sample_holding.sector
assert restored.industry == sample_holding.industry
def test_holding_to_dict_excludes_runtime_fields(sample_holding):
"""to_dict() must not include current_price, current_value, weight, etc."""
# TODO: implement
raise NotImplementedError
d = sample_holding.to_dict()
for field in ("current_price", "current_value", "cost_basis",
"unrealized_pnl", "unrealized_pnl_pct", "weight"):
assert field not in d
# ---------------------------------------------------------------------------
@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding):
def test_trade_to_dict_round_trip(sample_trade):
"""Trade.to_dict() -> Trade.from_dict() must be lossless."""
# TODO: implement
raise NotImplementedError
d = sample_trade.to_dict()
restored = Trade.from_dict(d)
assert restored.trade_id == sample_trade.trade_id
assert restored.portfolio_id == sample_trade.portfolio_id
assert restored.ticker == sample_trade.ticker
assert restored.action == sample_trade.action
assert restored.shares == sample_trade.shares
assert restored.price == sample_trade.price
assert restored.total_value == sample_trade.total_value
assert restored.trade_date == sample_trade.trade_date
assert restored.rationale == sample_trade.rationale
assert restored.signal_source == sample_trade.signal_source
assert restored.metadata == sample_trade.metadata
# ---------------------------------------------------------------------------
@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade):
def test_snapshot_to_dict_round_trip(sample_snapshot):
"""PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip."""
# TODO: implement
raise NotImplementedError
d = sample_snapshot.to_dict()
restored = PortfolioSnapshot.from_dict(d)
assert restored.snapshot_id == sample_snapshot.snapshot_id
assert restored.portfolio_id == sample_snapshot.portfolio_id
assert restored.snapshot_date == sample_snapshot.snapshot_date
assert restored.total_value == sample_snapshot.total_value
assert restored.cash == sample_snapshot.cash
assert restored.equity_value == sample_snapshot.equity_value
assert restored.num_positions == sample_snapshot.num_positions
assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot
assert restored.metadata == sample_snapshot.metadata
def test_snapshot_from_dict_parses_holdings_snapshot_json_string():
"""from_dict() must parse holdings_snapshot when it arrives as a JSON string."""
import json
holdings = [{"ticker": "AAPL", "shares": 10.0}]
data = {
"snapshot_id": "snap-1",
"portfolio_id": "pid-1",
"snapshot_date": "2026-03-20",
"total_value": 110_000.0,
"cash": 10_000.0,
"equity_value": 100_000.0,
"num_positions": 1,
"holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase
}
snap = PortfolioSnapshot.from_dict(data)
assert snap.holdings_snapshot == holdings
# ---------------------------------------------------------------------------
@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot):
def test_holding_enrich_computes_current_value(sample_holding):
"""enrich() must set current_value = current_price * shares."""
# TODO: implement
# sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
# assert sample_holding.current_value == 200.0 * sample_holding.shares
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
assert sample_holding.current_value == 200.0 * sample_holding.shares
def test_holding_enrich_computes_unrealized_pnl(sample_holding):
"""enrich() must set unrealized_pnl = current_value - cost_basis."""
# TODO: implement
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
expected_cost_basis = sample_holding.avg_cost * sample_holding.shares
expected_pnl = sample_holding.current_value - expected_cost_basis
assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl)
def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding):
"""enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis."""
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
cost_basis = sample_holding.avg_cost * sample_holding.shares
expected_pct = sample_holding.unrealized_pnl / cost_basis
assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct)
def test_holding_enrich_computes_weight(sample_holding):
"""enrich() must set weight = current_value / portfolio_total_value."""
# TODO: implement
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
expected_weight = sample_holding.current_value / 100_000.0
assert sample_holding.weight == pytest.approx(expected_weight)
def test_holding_enrich_returns_self(sample_holding):
"""enrich() must return self for chaining."""
result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
assert result is sample_holding
def test_holding_enrich_handles_zero_cost(sample_holding):
"""When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError)."""
# TODO: implement
raise NotImplementedError
sample_holding.avg_cost = 0.0
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
assert sample_holding.unrealized_pnl_pct == 0.0
def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
"""When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError)."""
# TODO: implement
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0)
assert sample_holding.weight == 0.0
# ---------------------------------------------------------------------------
@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
"""Portfolio.enrich() must compute total_value = cash + sum(holding.current_value)."""
# TODO: implement
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
sample_portfolio.enrich([sample_holding])
expected_equity = 200.0 * sample_holding.shares
assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity)
def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding):
"""Portfolio.enrich() must set equity_value = sum(holding.current_value)."""
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
sample_portfolio.enrich([sample_holding])
assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares)
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
"""Portfolio.enrich() must compute cash_pct = cash / total_value."""
# TODO: implement
raise NotImplementedError
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
sample_portfolio.enrich([sample_holding])
expected_pct = sample_portfolio.cash / sample_portfolio.total_value
assert sample_portfolio.cash_pct == pytest.approx(expected_pct)
def test_portfolio_enrich_returns_self(sample_portfolio):
"""enrich() must return self for chaining."""
result = sample_portfolio.enrich([])
assert result is sample_portfolio
def test_portfolio_enrich_no_holdings(sample_portfolio):
"""Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash."""
sample_portfolio.enrich([])
assert sample_portfolio.equity_value == 0.0
assert sample_portfolio.total_value == sample_portfolio.cash
assert sample_portfolio.cash_pct == 1.0

View File

@ -12,10 +12,14 @@ Run::
from __future__ import annotations
import json
from pathlib import Path
import pytest
from tradingagents.portfolio.exceptions import ReportStoreError
from tradingagents.portfolio.report_store import ReportStore
# ---------------------------------------------------------------------------
# Macro scan
@ -24,19 +28,17 @@ import pytest
def test_save_and_load_scan(report_store, tmp_reports):
"""save_scan() then load_scan() must return the original data."""
# TODO: implement
# data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
# path = report_store.save_scan("2026-03-20", data)
# assert path.exists()
# loaded = report_store.load_scan("2026-03-20")
# assert loaded == data
raise NotImplementedError
data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
path = report_store.save_scan("2026-03-20", data)
assert path.exists()
loaded = report_store.load_scan("2026-03-20")
assert loaded == data
def test_load_scan_returns_none_for_missing_file(report_store):
"""load_scan() must return None when the file does not exist."""
# TODO: implement
raise NotImplementedError
result = report_store.load_scan("1900-01-01")
assert result is None
# ---------------------------------------------------------------------------
@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store):
def test_save_and_load_analysis(report_store):
"""save_analysis() then load_analysis() must return the original data."""
# TODO: implement
raise NotImplementedError
data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92}
report_store.save_analysis("2026-03-20", "AAPL", data)
loaded = report_store.load_analysis("2026-03-20", "AAPL")
assert loaded == data
def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
"""Ticker symbol must be stored as uppercase in the directory path."""
# TODO: implement
raise NotImplementedError
data = {"ticker": "aapl"}
report_store.save_analysis("2026-03-20", "aapl", data)
expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json"
assert expected.exists()
# load with lowercase should still work
loaded = report_store.load_analysis("2026-03-20", "aapl")
assert loaded == data
# ---------------------------------------------------------------------------
@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
def test_save_and_load_holding_review(report_store):
"""save_holding_review() then load_holding_review() must round-trip."""
# TODO: implement
raise NotImplementedError
data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0}
report_store.save_holding_review("2026-03-20", "MSFT", data)
loaded = report_store.load_holding_review("2026-03-20", "MSFT")
assert loaded == data
def test_load_holding_review_returns_none_for_missing(report_store):
"""load_holding_review() must return None when the file does not exist."""
# TODO: implement
raise NotImplementedError
result = report_store.load_holding_review("1900-01-01", "ZZZZ")
assert result is None
# ---------------------------------------------------------------------------
@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store):
def test_save_and_load_risk_metrics(report_store):
"""save_risk_metrics() then load_risk_metrics() must round-trip."""
# TODO: implement
raise NotImplementedError
data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12}
report_store.save_risk_metrics("2026-03-20", "pid-123", data)
loaded = report_store.load_risk_metrics("2026-03-20", "pid-123")
assert loaded == data
# ---------------------------------------------------------------------------
@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store):
def test_save_and_load_pm_decision_json(report_store):
"""save_pm_decision() then load_pm_decision() must round-trip JSON."""
# TODO: implement
# decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
# report_store.save_pm_decision("2026-03-20", "pid-123", decision)
# loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
# assert loaded == decision
raise NotImplementedError
decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
report_store.save_pm_decision("2026-03-20", "pid-123", decision)
loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
assert loaded == decision
def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports):
"""When markdown is passed to save_pm_decision(), .md file must be written."""
# TODO: implement
raise NotImplementedError
decision = {"sells": [], "buys": []}
md_text = "# Decision\n\nHold everything."
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text)
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
assert md_path.exists()
assert md_path.read_text(encoding="utf-8") == md_text
def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports):
"""When markdown=None, no .md file should be written."""
# TODO: implement
raise NotImplementedError
decision = {"sells": [], "buys": []}
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None)
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
assert not md_path.exists()
def test_load_pm_decision_returns_none_for_missing(report_store):
"""load_pm_decision() must return None when the file does not exist."""
# TODO: implement
raise NotImplementedError
result = report_store.load_pm_decision("1900-01-01", "pid-none")
assert result is None
def test_list_pm_decisions(report_store):
"""list_pm_decisions() must return all saved decision paths, newest first."""
# TODO: implement
# Save decisions for multiple dates, verify order
raise NotImplementedError
dates = ["2026-03-18", "2026-03-19", "2026-03-20"]
for d in dates:
report_store.save_pm_decision(d, "pid-abc", {"date": d})
paths = report_store.list_pm_decisions("pid-abc")
assert len(paths) == 3
# Sorted newest first by ISO date string ordering
date_parts = [p.parent.parent.name for p in paths]
assert date_parts == sorted(dates, reverse=True)
# ---------------------------------------------------------------------------
@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store):
def test_directories_created_on_write(report_store, tmp_reports):
"""Directories must be created automatically on first write."""
# TODO: implement
# assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists()
# report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
# assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir()
raise NotImplementedError
target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio"
assert not target_dir.exists()
report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
assert target_dir.is_dir()
def test_json_formatted_with_indent(report_store, tmp_reports):
"""Written JSON files must use indent=2 for human readability."""
# TODO: implement
# Write a file, read the raw bytes, verify indentation
raise NotImplementedError
data = {"key": "value", "nested": {"a": 1}}
path = report_store.save_scan("2026-03-20", data)
raw = path.read_text(encoding="utf-8")
# indent=2 means lines like ' "key": ...'
assert ' "key"' in raw
def test_read_json_raises_on_corrupt_file(report_store, tmp_reports):
"""_read_json must raise ReportStoreError for corrupt JSON."""
corrupt = tmp_reports / "corrupt.json"
corrupt.write_text("not valid json{{{", encoding="utf-8")
with pytest.raises(ReportStoreError):
report_store._read_json(corrupt)

View File

@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades (
trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE,
ticker TEXT NOT NULL,
action TEXT NOT NULL,
action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')),
shares NUMERIC(18,6) NOT NULL CHECK (shares > 0),
price NUMERIC(18,4) NOT NULL CHECK (price > 0),
total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0),
trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(),
rationale TEXT, -- PM agent rationale for this trade
signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent'
metadata JSONB NOT NULL DEFAULT '{}',
CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL'))
metadata JSONB NOT NULL DEFAULT '{}'
);
COMMENT ON TABLE trades IS

View File

@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with:
- ``from_dict()`` class method for deserialisation
- ``enrich()`` for attaching runtime-computed fields
**float vs Decimal** monetary fields (cash, price, shares, etc.) use plain
``float`` throughout. Rationale:
1. This is **mock trading only** no real money changes hands. The cost of a
subtle floating-point rounding error is zero.
2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float``
already. Converting to ``Decimal`` at the boundary would require a custom
JSON encoder *and* decoder everywhere, for no practical gain.
3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises
``TypeError`` without a custom encoder.
4. If this ever becomes real-money trading, replace ``float`` with
``decimal.Decimal`` and add a ``DecimalEncoder`` the interface is
identical and the change is localised to this file.
See ``docs/portfolio/02_data_models.md`` for full field specifications.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any
@ -48,8 +63,17 @@ class Portfolio:
Runtime-computed fields are excluded.
"""
# TODO: implement
raise NotImplementedError
return {
"portfolio_id": self.portfolio_id,
"name": self.name,
"cash": self.cash,
"initial_cash": self.initial_cash,
"currency": self.currency,
"created_at": self.created_at,
"updated_at": self.updated_at,
"report_path": self.report_path,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Portfolio":
@ -57,8 +81,17 @@ class Portfolio:
Missing optional fields default gracefully. Extra keys are ignored.
"""
# TODO: implement
raise NotImplementedError
return cls(
portfolio_id=data["portfolio_id"],
name=data["name"],
cash=float(data["cash"]),
initial_cash=float(data["initial_cash"]),
currency=data.get("currency", "USD"),
created_at=data.get("created_at", ""),
updated_at=data.get("updated_at", ""),
report_path=data.get("report_path"),
metadata=data.get("metadata") or {},
)
def enrich(self, holdings: list["Holding"]) -> "Portfolio":
"""Compute total_value, equity_value, cash_pct from holdings.
@ -69,8 +102,12 @@ class Portfolio:
holdings: List of Holding objects with current_value populated
(i.e., ``holding.enrich()`` already called).
"""
# TODO: implement
raise NotImplementedError
self.equity_value = sum(
h.current_value for h in holdings if h.current_value is not None
)
self.total_value = self.cash + self.equity_value
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
return self
# ---------------------------------------------------------------------------
@ -106,14 +143,32 @@ class Holding:
def to_dict(self) -> dict[str, Any]:
"""Serialise stored fields only (runtime-computed fields excluded)."""
# TODO: implement
raise NotImplementedError
return {
"holding_id": self.holding_id,
"portfolio_id": self.portfolio_id,
"ticker": self.ticker,
"shares": self.shares,
"avg_cost": self.avg_cost,
"sector": self.sector,
"industry": self.industry,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Holding":
"""Deserialise from a DB row or JSON dict."""
# TODO: implement
raise NotImplementedError
return cls(
holding_id=data["holding_id"],
portfolio_id=data["portfolio_id"],
ticker=data["ticker"],
shares=float(data["shares"]),
avg_cost=float(data["avg_cost"]),
sector=data.get("sector"),
industry=data.get("industry"),
created_at=data.get("created_at", ""),
updated_at=data.get("updated_at", ""),
)
def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding":
"""Populate runtime-computed fields in-place and return self.
@ -129,8 +184,17 @@ class Holding:
current_price: Latest market price for this ticker.
portfolio_total_value: Total portfolio value (cash + equity).
"""
# TODO: implement
raise NotImplementedError
self.current_price = current_price
self.current_value = current_price * self.shares
self.cost_basis = self.avg_cost * self.shares
self.unrealized_pnl = self.current_value - self.cost_basis
self.unrealized_pnl_pct = (
self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0
)
self.weight = (
self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0
)
return self
# ---------------------------------------------------------------------------
@ -159,14 +223,36 @@ class Trade:
def to_dict(self) -> dict[str, Any]:
"""Serialise all fields."""
# TODO: implement
raise NotImplementedError
return {
"trade_id": self.trade_id,
"portfolio_id": self.portfolio_id,
"ticker": self.ticker,
"action": self.action,
"shares": self.shares,
"price": self.price,
"total_value": self.total_value,
"trade_date": self.trade_date,
"rationale": self.rationale,
"signal_source": self.signal_source,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "Trade":
"""Deserialise from a DB row or JSON dict."""
# TODO: implement
raise NotImplementedError
return cls(
trade_id=data["trade_id"],
portfolio_id=data["portfolio_id"],
ticker=data["ticker"],
action=data["action"],
shares=float(data["shares"]),
price=float(data["price"]),
total_value=float(data["total_value"]),
trade_date=data.get("trade_date", ""),
rationale=data.get("rationale"),
signal_source=data.get("signal_source"),
metadata=data.get("metadata") or {},
)
# ---------------------------------------------------------------------------
@ -194,8 +280,17 @@ class PortfolioSnapshot:
def to_dict(self) -> dict[str, Any]:
"""Serialise all fields. ``holdings_snapshot`` is already a list[dict]."""
# TODO: implement
raise NotImplementedError
return {
"snapshot_id": self.snapshot_id,
"portfolio_id": self.portfolio_id,
"snapshot_date": self.snapshot_date,
"total_value": self.total_value,
"cash": self.cash,
"equity_value": self.equity_value,
"num_positions": self.num_positions,
"holdings_snapshot": self.holdings_snapshot,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot":
@ -203,5 +298,17 @@ class PortfolioSnapshot:
``holdings_snapshot`` is parsed from a JSON string when needed.
"""
# TODO: implement
raise NotImplementedError
holdings_snapshot = data.get("holdings_snapshot", [])
if isinstance(holdings_snapshot, str):
holdings_snapshot = json.loads(holdings_snapshot)
return cls(
snapshot_id=data["snapshot_id"],
portfolio_id=data["portfolio_id"],
snapshot_date=data["snapshot_date"],
total_value=float(data["total_value"]),
cash=float(data["cash"]),
equity_value=float(data["equity_value"]),
num_positions=int(data["num_positions"]),
holdings_snapshot=holdings_snapshot,
metadata=data.get("metadata") or {},
)

View File

@ -28,9 +28,12 @@ Usage::
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from tradingagents.portfolio.exceptions import ReportStoreError
class ReportStore:
"""Filesystem document store for all portfolio-related reports.
@ -48,8 +51,7 @@ class ReportStore:
Override via the ``PORTFOLIO_DATA_DIR`` env var or
``get_portfolio_config()["data_dir"]``.
"""
# TODO: implement — store Path(base_dir), resolve as needed
raise NotImplementedError
self._base_dir = Path(base_dir)
# ------------------------------------------------------------------
# Internal helpers
@ -60,8 +62,7 @@ class ReportStore:
Path: ``{base_dir}/daily/{date}/portfolio/``
"""
# TODO: implement
raise NotImplementedError
return self._base_dir / "daily" / date / "portfolio"
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
"""Write a dict to a JSON file, creating parent directories as needed.
@ -76,8 +77,12 @@ class ReportStore:
Raises:
ReportStoreError: On filesystem write failure.
"""
# TODO: implement
raise NotImplementedError
try:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
return path
except OSError as exc:
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
def _read_json(self, path: Path) -> dict[str, Any] | None:
"""Read a JSON file, returning None if the file does not exist.
@ -85,8 +90,12 @@ class ReportStore:
Raises:
ReportStoreError: On JSON parse error (file exists but is corrupt).
"""
# TODO: implement
raise NotImplementedError
if not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError as exc:
raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc
# ------------------------------------------------------------------
# Macro Scan
@ -104,13 +113,13 @@ class ReportStore:
Returns:
Path of the written file.
"""
# TODO: implement
raise NotImplementedError
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
return self._write_json(path, data)
def load_scan(self, date: str) -> dict[str, Any] | None:
"""Load macro scan summary. Returns None if the file does not exist."""
# TODO: implement
raise NotImplementedError
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
return self._read_json(path)
# ------------------------------------------------------------------
# Per-Ticker Analysis
@ -126,13 +135,13 @@ class ReportStore:
ticker: Ticker symbol (stored as uppercase).
data: Analysis output dict.
"""
# TODO: implement
raise NotImplementedError
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
return self._write_json(path, data)
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
"""Load per-ticker analysis JSON. Returns None if the file does not exist."""
# TODO: implement
raise NotImplementedError
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
return self._read_json(path)
# ------------------------------------------------------------------
# Holding Reviews
@ -153,13 +162,13 @@ class ReportStore:
ticker: Ticker symbol (stored as uppercase).
data: HoldingReviewerAgent output dict.
"""
# TODO: implement
raise NotImplementedError
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
return self._write_json(path, data)
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
"""Load holding review output. Returns None if the file does not exist."""
# TODO: implement
raise NotImplementedError
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
return self._read_json(path)
# ------------------------------------------------------------------
# Risk Metrics
@ -180,8 +189,8 @@ class ReportStore:
portfolio_id: UUID of the target portfolio.
data: Risk metrics dict (Sharpe, Sortino, VaR, etc.).
"""
# TODO: implement
raise NotImplementedError
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
return self._write_json(path, data)
def load_risk_metrics(
self,
@ -189,8 +198,8 @@ class ReportStore:
portfolio_id: str,
) -> dict[str, Any] | None:
"""Load risk metrics. Returns None if the file does not exist."""
# TODO: implement
raise NotImplementedError
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
return self._read_json(path)
# ------------------------------------------------------------------
# PM Decisions
@ -218,8 +227,15 @@ class ReportStore:
Returns:
Path of the written JSON file.
"""
# TODO: implement
raise NotImplementedError
json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
self._write_json(json_path, data)
if markdown is not None:
md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md"
try:
md_path.write_text(markdown, encoding="utf-8")
except OSError as exc:
raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc
return json_path
def load_pm_decision(
self,
@ -227,8 +243,8 @@ class ReportStore:
portfolio_id: str,
) -> dict[str, Any] | None:
"""Load PM decision JSON. Returns None if the file does not exist."""
# TODO: implement
raise NotImplementedError
path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
return self._read_json(path)
def list_pm_decisions(self, portfolio_id: str) -> list[Path]:
"""Return all saved PM decision JSON paths for portfolio_id, newest first.
@ -241,5 +257,5 @@ class ReportStore:
Returns:
Sorted list of Path objects, newest date first.
"""
# TODO: implement
raise NotImplementedError
pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json"
return sorted(self._base_dir.glob(pattern), reverse=True)