feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-03-20 11:16:39 +00:00
parent 7ea9866d1d
commit aa4dcdeb80
6 changed files with 464 additions and 148 deletions

View File

@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
from __future__ import annotations from __future__ import annotations
import os import os
import uuid
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from tradingagents.portfolio.models import (
Holding,
Portfolio,
PortfolioSnapshot,
Trade,
)
from tradingagents.portfolio.report_store import ReportStore
from tradingagents.portfolio.supabase_client import SupabaseClient
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Skip marker for Supabase integration tests # Skip marker for Supabase integration tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -51,31 +59,72 @@ def sample_holding_id() -> str:
@pytest.fixture @pytest.fixture
def sample_portfolio(sample_portfolio_id: str): def sample_portfolio(sample_portfolio_id: str) -> Portfolio:
"""Return an unsaved Portfolio instance for testing.""" """Return an unsaved Portfolio instance for testing."""
# TODO: implement — construct a Portfolio dataclass with test values return Portfolio(
raise NotImplementedError portfolio_id=sample_portfolio_id,
name="Test Portfolio",
cash=50_000.0,
initial_cash=100_000.0,
currency="USD",
created_at="2026-03-20T00:00:00Z",
updated_at="2026-03-20T00:00:00Z",
report_path="reports/daily/2026-03-20/portfolio",
metadata={"strategy": "test"},
)
@pytest.fixture @pytest.fixture
def sample_holding(sample_portfolio_id: str, sample_holding_id: str): def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding:
"""Return an unsaved Holding instance for testing.""" """Return an unsaved Holding instance for testing."""
# TODO: implement — construct a Holding dataclass with test values return Holding(
raise NotImplementedError holding_id=sample_holding_id,
portfolio_id=sample_portfolio_id,
ticker="AAPL",
shares=100.0,
avg_cost=150.0,
sector="Technology",
industry="Consumer Electronics",
created_at="2026-03-20T00:00:00Z",
updated_at="2026-03-20T00:00:00Z",
)
@pytest.fixture @pytest.fixture
def sample_trade(sample_portfolio_id: str): def sample_trade(sample_portfolio_id: str) -> Trade:
"""Return an unsaved Trade instance for testing.""" """Return an unsaved Trade instance for testing."""
# TODO: implement — construct a Trade dataclass with test values return Trade(
raise NotImplementedError trade_id="33333333-3333-3333-3333-333333333333",
portfolio_id=sample_portfolio_id,
ticker="AAPL",
action="BUY",
shares=100.0,
price=150.0,
total_value=15_000.0,
trade_date="2026-03-20T10:00:00Z",
rationale="Strong momentum signal",
signal_source="scanner",
metadata={"confidence": 0.85},
)
@pytest.fixture @pytest.fixture
def sample_snapshot(sample_portfolio_id: str): def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot:
"""Return an unsaved PortfolioSnapshot instance for testing.""" """Return an unsaved PortfolioSnapshot instance for testing."""
# TODO: implement — construct a PortfolioSnapshot dataclass with test values return PortfolioSnapshot(
raise NotImplementedError snapshot_id="44444444-4444-4444-4444-444444444444",
portfolio_id=sample_portfolio_id,
snapshot_date="2026-03-20",
total_value=115_000.0,
cash=50_000.0,
equity_value=65_000.0,
num_positions=2,
holdings_snapshot=[
{"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0},
{"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0},
],
metadata={"note": "end of day snapshot"},
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path:
@pytest.fixture @pytest.fixture
def report_store(tmp_reports: Path): def report_store(tmp_reports: Path) -> ReportStore:
"""ReportStore instance backed by a temporary directory.""" """ReportStore instance backed by a temporary directory."""
# TODO: implement — return ReportStore(base_dir=tmp_reports) return ReportStore(base_dir=tmp_reports)
raise NotImplementedError
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -104,7 +152,6 @@ def report_store(tmp_reports: Path):
@pytest.fixture @pytest.fixture
def mock_supabase_client(): def mock_supabase_client() -> MagicMock:
"""MagicMock of SupabaseClient for unit tests that don't hit the DB.""" """MagicMock of SupabaseClient for unit tests that don't hit the DB."""
# TODO: implement — return MagicMock(spec=SupabaseClient) return MagicMock(spec=SupabaseClient)
raise NotImplementedError

View File

@ -16,6 +16,13 @@ from __future__ import annotations
import pytest import pytest
from tradingagents.portfolio.models import (
Holding,
Portfolio,
PortfolioSnapshot,
Trade,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Portfolio round-trip # Portfolio round-trip
@ -24,19 +31,41 @@ import pytest
def test_portfolio_to_dict_round_trip(sample_portfolio): def test_portfolio_to_dict_round_trip(sample_portfolio):
"""Portfolio.to_dict() -> Portfolio.from_dict() must be lossless.""" """Portfolio.to_dict() -> Portfolio.from_dict() must be lossless."""
# TODO: implement d = sample_portfolio.to_dict()
# d = sample_portfolio.to_dict() restored = Portfolio.from_dict(d)
# restored = Portfolio.from_dict(d) assert restored.portfolio_id == sample_portfolio.portfolio_id
# assert restored.portfolio_id == sample_portfolio.portfolio_id assert restored.name == sample_portfolio.name
# assert restored.cash == sample_portfolio.cash assert restored.cash == sample_portfolio.cash
# ... all stored fields assert restored.initial_cash == sample_portfolio.initial_cash
raise NotImplementedError assert restored.currency == sample_portfolio.currency
assert restored.created_at == sample_portfolio.created_at
assert restored.updated_at == sample_portfolio.updated_at
assert restored.report_path == sample_portfolio.report_path
assert restored.metadata == sample_portfolio.metadata
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio): def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
"""to_dict() must not include computed fields (total_value, equity_value, cash_pct).""" """to_dict() must not include computed fields (total_value, equity_value, cash_pct)."""
# TODO: implement d = sample_portfolio.to_dict()
raise NotImplementedError assert "total_value" not in d
assert "equity_value" not in d
assert "cash_pct" not in d
def test_portfolio_from_dict_defaults_optional_fields():
"""from_dict() must tolerate missing optional fields."""
minimal = {
"portfolio_id": "pid-1",
"name": "Minimal",
"cash": 1000.0,
"initial_cash": 1000.0,
}
p = Portfolio.from_dict(minimal)
assert p.currency == "USD"
assert p.created_at == ""
assert p.updated_at == ""
assert p.report_path is None
assert p.metadata == {}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
def test_holding_to_dict_round_trip(sample_holding): def test_holding_to_dict_round_trip(sample_holding):
"""Holding.to_dict() -> Holding.from_dict() must be lossless.""" """Holding.to_dict() -> Holding.from_dict() must be lossless."""
# TODO: implement d = sample_holding.to_dict()
raise NotImplementedError restored = Holding.from_dict(d)
assert restored.holding_id == sample_holding.holding_id
assert restored.portfolio_id == sample_holding.portfolio_id
assert restored.ticker == sample_holding.ticker
assert restored.shares == sample_holding.shares
assert restored.avg_cost == sample_holding.avg_cost
assert restored.sector == sample_holding.sector
assert restored.industry == sample_holding.industry
def test_holding_to_dict_excludes_runtime_fields(sample_holding): def test_holding_to_dict_excludes_runtime_fields(sample_holding):
"""to_dict() must not include current_price, current_value, weight, etc.""" """to_dict() must not include current_price, current_value, weight, etc."""
# TODO: implement d = sample_holding.to_dict()
raise NotImplementedError for field in ("current_price", "current_value", "cost_basis",
"unrealized_pnl", "unrealized_pnl_pct", "weight"):
assert field not in d
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding):
def test_trade_to_dict_round_trip(sample_trade): def test_trade_to_dict_round_trip(sample_trade):
"""Trade.to_dict() -> Trade.from_dict() must be lossless.""" """Trade.to_dict() -> Trade.from_dict() must be lossless."""
# TODO: implement d = sample_trade.to_dict()
raise NotImplementedError restored = Trade.from_dict(d)
assert restored.trade_id == sample_trade.trade_id
assert restored.portfolio_id == sample_trade.portfolio_id
assert restored.ticker == sample_trade.ticker
assert restored.action == sample_trade.action
assert restored.shares == sample_trade.shares
assert restored.price == sample_trade.price
assert restored.total_value == sample_trade.total_value
assert restored.trade_date == sample_trade.trade_date
assert restored.rationale == sample_trade.rationale
assert restored.signal_source == sample_trade.signal_source
assert restored.metadata == sample_trade.metadata
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade):
def test_snapshot_to_dict_round_trip(sample_snapshot): def test_snapshot_to_dict_round_trip(sample_snapshot):
"""PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip.""" """PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip."""
# TODO: implement d = sample_snapshot.to_dict()
raise NotImplementedError restored = PortfolioSnapshot.from_dict(d)
assert restored.snapshot_id == sample_snapshot.snapshot_id
assert restored.portfolio_id == sample_snapshot.portfolio_id
assert restored.snapshot_date == sample_snapshot.snapshot_date
assert restored.total_value == sample_snapshot.total_value
assert restored.cash == sample_snapshot.cash
assert restored.equity_value == sample_snapshot.equity_value
assert restored.num_positions == sample_snapshot.num_positions
assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot
assert restored.metadata == sample_snapshot.metadata
def test_snapshot_from_dict_parses_holdings_snapshot_json_string():
"""from_dict() must parse holdings_snapshot when it arrives as a JSON string."""
import json
holdings = [{"ticker": "AAPL", "shares": 10.0}]
data = {
"snapshot_id": "snap-1",
"portfolio_id": "pid-1",
"snapshot_date": "2026-03-20",
"total_value": 110_000.0,
"cash": 10_000.0,
"equity_value": 100_000.0,
"num_positions": 1,
"holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase
}
snap = PortfolioSnapshot.from_dict(data)
assert snap.holdings_snapshot == holdings
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot):
def test_holding_enrich_computes_current_value(sample_holding): def test_holding_enrich_computes_current_value(sample_holding):
"""enrich() must set current_value = current_price * shares.""" """enrich() must set current_value = current_price * shares."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
# sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) assert sample_holding.current_value == 200.0 * sample_holding.shares
# assert sample_holding.current_value == 200.0 * sample_holding.shares
raise NotImplementedError
def test_holding_enrich_computes_unrealized_pnl(sample_holding): def test_holding_enrich_computes_unrealized_pnl(sample_holding):
"""enrich() must set unrealized_pnl = current_value - cost_basis.""" """enrich() must set unrealized_pnl = current_value - cost_basis."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
raise NotImplementedError expected_cost_basis = sample_holding.avg_cost * sample_holding.shares
expected_pnl = sample_holding.current_value - expected_cost_basis
assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl)
def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding):
"""enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis."""
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
cost_basis = sample_holding.avg_cost * sample_holding.shares
expected_pct = sample_holding.unrealized_pnl / cost_basis
assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct)
def test_holding_enrich_computes_weight(sample_holding): def test_holding_enrich_computes_weight(sample_holding):
"""enrich() must set weight = current_value / portfolio_total_value.""" """enrich() must set weight = current_value / portfolio_total_value."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
raise NotImplementedError expected_weight = sample_holding.current_value / 100_000.0
assert sample_holding.weight == pytest.approx(expected_weight)
def test_holding_enrich_returns_self(sample_holding):
"""enrich() must return self for chaining."""
result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
assert result is sample_holding
def test_holding_enrich_handles_zero_cost(sample_holding): def test_holding_enrich_handles_zero_cost(sample_holding):
"""When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError).""" """When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError)."""
# TODO: implement sample_holding.avg_cost = 0.0
raise NotImplementedError sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
assert sample_holding.unrealized_pnl_pct == 0.0
def test_holding_enrich_handles_zero_portfolio_value(sample_holding): def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
"""When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError).""" """When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError)."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0)
raise NotImplementedError assert sample_holding.weight == 0.0
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding): def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
"""Portfolio.enrich() must compute total_value = cash + sum(holding.current_value).""" """Portfolio.enrich() must compute total_value = cash + sum(holding.current_value)."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
raise NotImplementedError sample_portfolio.enrich([sample_holding])
expected_equity = 200.0 * sample_holding.shares
assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity)
def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding):
"""Portfolio.enrich() must set equity_value = sum(holding.current_value)."""
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
sample_portfolio.enrich([sample_holding])
assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares)
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding): def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
"""Portfolio.enrich() must compute cash_pct = cash / total_value.""" """Portfolio.enrich() must compute cash_pct = cash / total_value."""
# TODO: implement sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
raise NotImplementedError sample_portfolio.enrich([sample_holding])
expected_pct = sample_portfolio.cash / sample_portfolio.total_value
assert sample_portfolio.cash_pct == pytest.approx(expected_pct)
def test_portfolio_enrich_returns_self(sample_portfolio):
"""enrich() must return self for chaining."""
result = sample_portfolio.enrich([])
assert result is sample_portfolio
def test_portfolio_enrich_no_holdings(sample_portfolio):
"""Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash."""
sample_portfolio.enrich([])
assert sample_portfolio.equity_value == 0.0
assert sample_portfolio.total_value == sample_portfolio.cash
assert sample_portfolio.cash_pct == 1.0

View File

@ -12,10 +12,14 @@ Run::
from __future__ import annotations from __future__ import annotations
import json
from pathlib import Path from pathlib import Path
import pytest import pytest
from tradingagents.portfolio.exceptions import ReportStoreError
from tradingagents.portfolio.report_store import ReportStore
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Macro scan # Macro scan
@ -24,19 +28,17 @@ import pytest
def test_save_and_load_scan(report_store, tmp_reports): def test_save_and_load_scan(report_store, tmp_reports):
"""save_scan() then load_scan() must return the original data.""" """save_scan() then load_scan() must return the original data."""
# TODO: implement data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
# data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"} path = report_store.save_scan("2026-03-20", data)
# path = report_store.save_scan("2026-03-20", data) assert path.exists()
# assert path.exists() loaded = report_store.load_scan("2026-03-20")
# loaded = report_store.load_scan("2026-03-20") assert loaded == data
# assert loaded == data
raise NotImplementedError
def test_load_scan_returns_none_for_missing_file(report_store): def test_load_scan_returns_none_for_missing_file(report_store):
"""load_scan() must return None when the file does not exist.""" """load_scan() must return None when the file does not exist."""
# TODO: implement result = report_store.load_scan("1900-01-01")
raise NotImplementedError assert result is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store):
def test_save_and_load_analysis(report_store): def test_save_and_load_analysis(report_store):
"""save_analysis() then load_analysis() must return the original data.""" """save_analysis() then load_analysis() must return the original data."""
# TODO: implement data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92}
raise NotImplementedError report_store.save_analysis("2026-03-20", "AAPL", data)
loaded = report_store.load_analysis("2026-03-20", "AAPL")
assert loaded == data
def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports): def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
"""Ticker symbol must be stored as uppercase in the directory path.""" """Ticker symbol must be stored as uppercase in the directory path."""
# TODO: implement data = {"ticker": "aapl"}
raise NotImplementedError report_store.save_analysis("2026-03-20", "aapl", data)
expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json"
assert expected.exists()
# load with lowercase should still work
loaded = report_store.load_analysis("2026-03-20", "aapl")
assert loaded == data
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
def test_save_and_load_holding_review(report_store): def test_save_and_load_holding_review(report_store):
"""save_holding_review() then load_holding_review() must round-trip.""" """save_holding_review() then load_holding_review() must round-trip."""
# TODO: implement data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0}
raise NotImplementedError report_store.save_holding_review("2026-03-20", "MSFT", data)
loaded = report_store.load_holding_review("2026-03-20", "MSFT")
assert loaded == data
def test_load_holding_review_returns_none_for_missing(report_store): def test_load_holding_review_returns_none_for_missing(report_store):
"""load_holding_review() must return None when the file does not exist.""" """load_holding_review() must return None when the file does not exist."""
# TODO: implement result = report_store.load_holding_review("1900-01-01", "ZZZZ")
raise NotImplementedError assert result is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store):
def test_save_and_load_risk_metrics(report_store): def test_save_and_load_risk_metrics(report_store):
"""save_risk_metrics() then load_risk_metrics() must round-trip.""" """save_risk_metrics() then load_risk_metrics() must round-trip."""
# TODO: implement data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12}
raise NotImplementedError report_store.save_risk_metrics("2026-03-20", "pid-123", data)
loaded = report_store.load_risk_metrics("2026-03-20", "pid-123")
assert loaded == data
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store):
def test_save_and_load_pm_decision_json(report_store): def test_save_and_load_pm_decision_json(report_store):
"""save_pm_decision() then load_pm_decision() must round-trip JSON.""" """save_pm_decision() then load_pm_decision() must round-trip JSON."""
# TODO: implement decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
# decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]} report_store.save_pm_decision("2026-03-20", "pid-123", decision)
# report_store.save_pm_decision("2026-03-20", "pid-123", decision) loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
# loaded = report_store.load_pm_decision("2026-03-20", "pid-123") assert loaded == decision
# assert loaded == decision
raise NotImplementedError
def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports): def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports):
"""When markdown is passed to save_pm_decision(), .md file must be written.""" """When markdown is passed to save_pm_decision(), .md file must be written."""
# TODO: implement decision = {"sells": [], "buys": []}
raise NotImplementedError md_text = "# Decision\n\nHold everything."
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text)
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
assert md_path.exists()
assert md_path.read_text(encoding="utf-8") == md_text
def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports): def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports):
"""When markdown=None, no .md file should be written.""" """When markdown=None, no .md file should be written."""
# TODO: implement decision = {"sells": [], "buys": []}
raise NotImplementedError report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None)
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
assert not md_path.exists()
def test_load_pm_decision_returns_none_for_missing(report_store): def test_load_pm_decision_returns_none_for_missing(report_store):
"""load_pm_decision() must return None when the file does not exist.""" """load_pm_decision() must return None when the file does not exist."""
# TODO: implement result = report_store.load_pm_decision("1900-01-01", "pid-none")
raise NotImplementedError assert result is None
def test_list_pm_decisions(report_store): def test_list_pm_decisions(report_store):
"""list_pm_decisions() must return all saved decision paths, newest first.""" """list_pm_decisions() must return all saved decision paths, newest first."""
# TODO: implement dates = ["2026-03-18", "2026-03-19", "2026-03-20"]
# Save decisions for multiple dates, verify order for d in dates:
raise NotImplementedError report_store.save_pm_decision(d, "pid-abc", {"date": d})
paths = report_store.list_pm_decisions("pid-abc")
assert len(paths) == 3
# Sorted newest first by ISO date string ordering
date_parts = [p.parent.parent.name for p in paths]
assert date_parts == sorted(dates, reverse=True)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store):
def test_directories_created_on_write(report_store, tmp_reports): def test_directories_created_on_write(report_store, tmp_reports):
"""Directories must be created automatically on first write.""" """Directories must be created automatically on first write."""
# TODO: implement target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio"
# assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists() assert not target_dir.exists()
# report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2}) report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
# assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir() assert target_dir.is_dir()
raise NotImplementedError
def test_json_formatted_with_indent(report_store, tmp_reports): def test_json_formatted_with_indent(report_store, tmp_reports):
"""Written JSON files must use indent=2 for human readability.""" """Written JSON files must use indent=2 for human readability."""
# TODO: implement data = {"key": "value", "nested": {"a": 1}}
# Write a file, read the raw bytes, verify indentation path = report_store.save_scan("2026-03-20", data)
raise NotImplementedError raw = path.read_text(encoding="utf-8")
# indent=2 means lines like ' "key": ...'
assert ' "key"' in raw
def test_read_json_raises_on_corrupt_file(report_store, tmp_reports):
"""_read_json must raise ReportStoreError for corrupt JSON."""
corrupt = tmp_reports / "corrupt.json"
corrupt.write_text("not valid json{{{", encoding="utf-8")
with pytest.raises(ReportStoreError):
report_store._read_json(corrupt)

View File

@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades (
trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(), trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE, portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE,
ticker TEXT NOT NULL, ticker TEXT NOT NULL,
action TEXT NOT NULL, action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')),
shares NUMERIC(18,6) NOT NULL CHECK (shares > 0), shares NUMERIC(18,6) NOT NULL CHECK (shares > 0),
price NUMERIC(18,4) NOT NULL CHECK (price > 0), price NUMERIC(18,4) NOT NULL CHECK (price > 0),
total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0), total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0),
trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(), trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(),
rationale TEXT, -- PM agent rationale for this trade rationale TEXT, -- PM agent rationale for this trade
signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent' signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent'
metadata JSONB NOT NULL DEFAULT '{}', metadata JSONB NOT NULL DEFAULT '{}'
CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL'))
); );
COMMENT ON TABLE trades IS COMMENT ON TABLE trades IS

View File

@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with:
- ``from_dict()`` class method for deserialisation - ``from_dict()`` class method for deserialisation
- ``enrich()`` for attaching runtime-computed fields - ``enrich()`` for attaching runtime-computed fields
**float vs Decimal** monetary fields (cash, price, shares, etc.) use plain
``float`` throughout. Rationale:
1. This is **mock trading only** no real money changes hands. The cost of a
subtle floating-point rounding error is zero.
2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float``
already. Converting to ``Decimal`` at the boundary would require a custom
JSON encoder *and* decoder everywhere, for no practical gain.
3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises
``TypeError`` without a custom encoder.
4. If this ever becomes real-money trading, replace ``float`` with
``decimal.Decimal`` and add a ``DecimalEncoder`` the interface is
identical and the change is localised to this file.
See ``docs/portfolio/02_data_models.md`` for full field specifications. See ``docs/portfolio/02_data_models.md`` for full field specifications.
""" """
from __future__ import annotations from __future__ import annotations
import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@ -48,8 +63,17 @@ class Portfolio:
Runtime-computed fields are excluded. Runtime-computed fields are excluded.
""" """
# TODO: implement return {
raise NotImplementedError "portfolio_id": self.portfolio_id,
"name": self.name,
"cash": self.cash,
"initial_cash": self.initial_cash,
"currency": self.currency,
"created_at": self.created_at,
"updated_at": self.updated_at,
"report_path": self.report_path,
"metadata": self.metadata,
}
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "Portfolio": def from_dict(cls, data: dict[str, Any]) -> "Portfolio":
@ -57,8 +81,17 @@ class Portfolio:
Missing optional fields default gracefully. Extra keys are ignored. Missing optional fields default gracefully. Extra keys are ignored.
""" """
# TODO: implement return cls(
raise NotImplementedError portfolio_id=data["portfolio_id"],
name=data["name"],
cash=float(data["cash"]),
initial_cash=float(data["initial_cash"]),
currency=data.get("currency", "USD"),
created_at=data.get("created_at", ""),
updated_at=data.get("updated_at", ""),
report_path=data.get("report_path"),
metadata=data.get("metadata") or {},
)
def enrich(self, holdings: list["Holding"]) -> "Portfolio": def enrich(self, holdings: list["Holding"]) -> "Portfolio":
"""Compute total_value, equity_value, cash_pct from holdings. """Compute total_value, equity_value, cash_pct from holdings.
@ -69,8 +102,12 @@ class Portfolio:
holdings: List of Holding objects with current_value populated holdings: List of Holding objects with current_value populated
(i.e., ``holding.enrich()`` already called). (i.e., ``holding.enrich()`` already called).
""" """
# TODO: implement self.equity_value = sum(
raise NotImplementedError h.current_value for h in holdings if h.current_value is not None
)
self.total_value = self.cash + self.equity_value
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
return self
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -106,14 +143,32 @@ class Holding:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Serialise stored fields only (runtime-computed fields excluded).""" """Serialise stored fields only (runtime-computed fields excluded)."""
# TODO: implement return {
raise NotImplementedError "holding_id": self.holding_id,
"portfolio_id": self.portfolio_id,
"ticker": self.ticker,
"shares": self.shares,
"avg_cost": self.avg_cost,
"sector": self.sector,
"industry": self.industry,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "Holding": def from_dict(cls, data: dict[str, Any]) -> "Holding":
"""Deserialise from a DB row or JSON dict.""" """Deserialise from a DB row or JSON dict."""
# TODO: implement return cls(
raise NotImplementedError holding_id=data["holding_id"],
portfolio_id=data["portfolio_id"],
ticker=data["ticker"],
shares=float(data["shares"]),
avg_cost=float(data["avg_cost"]),
sector=data.get("sector"),
industry=data.get("industry"),
created_at=data.get("created_at", ""),
updated_at=data.get("updated_at", ""),
)
def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding": def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding":
"""Populate runtime-computed fields in-place and return self. """Populate runtime-computed fields in-place and return self.
@ -129,8 +184,17 @@ class Holding:
current_price: Latest market price for this ticker. current_price: Latest market price for this ticker.
portfolio_total_value: Total portfolio value (cash + equity). portfolio_total_value: Total portfolio value (cash + equity).
""" """
# TODO: implement self.current_price = current_price
raise NotImplementedError self.current_value = current_price * self.shares
self.cost_basis = self.avg_cost * self.shares
self.unrealized_pnl = self.current_value - self.cost_basis
self.unrealized_pnl_pct = (
self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0
)
self.weight = (
self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0
)
return self
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -159,14 +223,36 @@ class Trade:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Serialise all fields.""" """Serialise all fields."""
# TODO: implement return {
raise NotImplementedError "trade_id": self.trade_id,
"portfolio_id": self.portfolio_id,
"ticker": self.ticker,
"action": self.action,
"shares": self.shares,
"price": self.price,
"total_value": self.total_value,
"trade_date": self.trade_date,
"rationale": self.rationale,
"signal_source": self.signal_source,
"metadata": self.metadata,
}
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "Trade": def from_dict(cls, data: dict[str, Any]) -> "Trade":
"""Deserialise from a DB row or JSON dict.""" """Deserialise from a DB row or JSON dict."""
# TODO: implement return cls(
raise NotImplementedError trade_id=data["trade_id"],
portfolio_id=data["portfolio_id"],
ticker=data["ticker"],
action=data["action"],
shares=float(data["shares"]),
price=float(data["price"]),
total_value=float(data["total_value"]),
trade_date=data.get("trade_date", ""),
rationale=data.get("rationale"),
signal_source=data.get("signal_source"),
metadata=data.get("metadata") or {},
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -194,8 +280,17 @@ class PortfolioSnapshot:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Serialise all fields. ``holdings_snapshot`` is already a list[dict].""" """Serialise all fields. ``holdings_snapshot`` is already a list[dict]."""
# TODO: implement return {
raise NotImplementedError "snapshot_id": self.snapshot_id,
"portfolio_id": self.portfolio_id,
"snapshot_date": self.snapshot_date,
"total_value": self.total_value,
"cash": self.cash,
"equity_value": self.equity_value,
"num_positions": self.num_positions,
"holdings_snapshot": self.holdings_snapshot,
"metadata": self.metadata,
}
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot": def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot":
@ -203,5 +298,17 @@ class PortfolioSnapshot:
``holdings_snapshot`` is parsed from a JSON string when needed. ``holdings_snapshot`` is parsed from a JSON string when needed.
""" """
# TODO: implement holdings_snapshot = data.get("holdings_snapshot", [])
raise NotImplementedError if isinstance(holdings_snapshot, str):
holdings_snapshot = json.loads(holdings_snapshot)
return cls(
snapshot_id=data["snapshot_id"],
portfolio_id=data["portfolio_id"],
snapshot_date=data["snapshot_date"],
total_value=float(data["total_value"]),
cash=float(data["cash"]),
equity_value=float(data["equity_value"]),
num_positions=int(data["num_positions"]),
holdings_snapshot=holdings_snapshot,
metadata=data.get("metadata") or {},
)

View File

@ -28,9 +28,12 @@ Usage::
from __future__ import annotations from __future__ import annotations
import json
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from tradingagents.portfolio.exceptions import ReportStoreError
class ReportStore: class ReportStore:
"""Filesystem document store for all portfolio-related reports. """Filesystem document store for all portfolio-related reports.
@ -48,8 +51,7 @@ class ReportStore:
Override via the ``PORTFOLIO_DATA_DIR`` env var or Override via the ``PORTFOLIO_DATA_DIR`` env var or
``get_portfolio_config()["data_dir"]``. ``get_portfolio_config()["data_dir"]``.
""" """
# TODO: implement — store Path(base_dir), resolve as needed self._base_dir = Path(base_dir)
raise NotImplementedError
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers
@ -60,8 +62,7 @@ class ReportStore:
Path: ``{base_dir}/daily/{date}/portfolio/`` Path: ``{base_dir}/daily/{date}/portfolio/``
""" """
# TODO: implement return self._base_dir / "daily" / date / "portfolio"
raise NotImplementedError
def _write_json(self, path: Path, data: dict[str, Any]) -> Path: def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
"""Write a dict to a JSON file, creating parent directories as needed. """Write a dict to a JSON file, creating parent directories as needed.
@ -76,8 +77,12 @@ class ReportStore:
Raises: Raises:
ReportStoreError: On filesystem write failure. ReportStoreError: On filesystem write failure.
""" """
# TODO: implement try:
raise NotImplementedError path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
return path
except OSError as exc:
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
def _read_json(self, path: Path) -> dict[str, Any] | None: def _read_json(self, path: Path) -> dict[str, Any] | None:
"""Read a JSON file, returning None if the file does not exist. """Read a JSON file, returning None if the file does not exist.
@ -85,8 +90,12 @@ class ReportStore:
Raises: Raises:
ReportStoreError: On JSON parse error (file exists but is corrupt). ReportStoreError: On JSON parse error (file exists but is corrupt).
""" """
# TODO: implement if not path.exists():
raise NotImplementedError return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except json.JSONDecodeError as exc:
raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Macro Scan # Macro Scan
@ -104,13 +113,13 @@ class ReportStore:
Returns: Returns:
Path of the written file. Path of the written file.
""" """
# TODO: implement path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
raise NotImplementedError return self._write_json(path, data)
def load_scan(self, date: str) -> dict[str, Any] | None: def load_scan(self, date: str) -> dict[str, Any] | None:
"""Load macro scan summary. Returns None if the file does not exist.""" """Load macro scan summary. Returns None if the file does not exist."""
# TODO: implement path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
raise NotImplementedError return self._read_json(path)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Per-Ticker Analysis # Per-Ticker Analysis
@ -126,13 +135,13 @@ class ReportStore:
ticker: Ticker symbol (stored as uppercase). ticker: Ticker symbol (stored as uppercase).
data: Analysis output dict. data: Analysis output dict.
""" """
# TODO: implement path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
raise NotImplementedError return self._write_json(path, data)
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None: def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
"""Load per-ticker analysis JSON. Returns None if the file does not exist.""" """Load per-ticker analysis JSON. Returns None if the file does not exist."""
# TODO: implement path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
raise NotImplementedError return self._read_json(path)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Holding Reviews # Holding Reviews
@ -153,13 +162,13 @@ class ReportStore:
ticker: Ticker symbol (stored as uppercase). ticker: Ticker symbol (stored as uppercase).
data: HoldingReviewerAgent output dict. data: HoldingReviewerAgent output dict.
""" """
# TODO: implement path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
raise NotImplementedError return self._write_json(path, data)
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None: def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
"""Load holding review output. Returns None if the file does not exist.""" """Load holding review output. Returns None if the file does not exist."""
# TODO: implement path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
raise NotImplementedError return self._read_json(path)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Risk Metrics # Risk Metrics
@ -180,8 +189,8 @@ class ReportStore:
portfolio_id: UUID of the target portfolio. portfolio_id: UUID of the target portfolio.
data: Risk metrics dict (Sharpe, Sortino, VaR, etc.). data: Risk metrics dict (Sharpe, Sortino, VaR, etc.).
""" """
# TODO: implement path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
raise NotImplementedError return self._write_json(path, data)
def load_risk_metrics( def load_risk_metrics(
self, self,
@ -189,8 +198,8 @@ class ReportStore:
portfolio_id: str, portfolio_id: str,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Load risk metrics. Returns None if the file does not exist.""" """Load risk metrics. Returns None if the file does not exist."""
# TODO: implement path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
raise NotImplementedError return self._read_json(path)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# PM Decisions # PM Decisions
@ -218,8 +227,15 @@ class ReportStore:
Returns: Returns:
Path of the written JSON file. Path of the written JSON file.
""" """
# TODO: implement json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
raise NotImplementedError self._write_json(json_path, data)
if markdown is not None:
md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md"
try:
md_path.write_text(markdown, encoding="utf-8")
except OSError as exc:
raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc
return json_path
def load_pm_decision( def load_pm_decision(
self, self,
@ -227,8 +243,8 @@ class ReportStore:
portfolio_id: str, portfolio_id: str,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Load PM decision JSON. Returns None if the file does not exist.""" """Load PM decision JSON. Returns None if the file does not exist."""
# TODO: implement path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
raise NotImplementedError return self._read_json(path)
def list_pm_decisions(self, portfolio_id: str) -> list[Path]: def list_pm_decisions(self, portfolio_id: str) -> list[Path]:
"""Return all saved PM decision JSON paths for portfolio_id, newest first. """Return all saved PM decision JSON paths for portfolio_id, newest first.
@ -241,5 +257,5 @@ class ReportStore:
Returns: Returns:
Sorted list of Path objects, newest date first. Sorted list of Path objects, newest date first.
""" """
# TODO: implement pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json"
raise NotImplementedError return sorted(self._base_dir.glob(pattern), reverse=True)