feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint
Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
parent
7ea9866d1d
commit
aa4dcdeb80
|
|
@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.portfolio.models import (
|
||||
Holding,
|
||||
Portfolio,
|
||||
PortfolioSnapshot,
|
||||
Trade,
|
||||
)
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skip marker for Supabase integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -51,31 +59,72 @@ def sample_holding_id() -> str:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_portfolio(sample_portfolio_id: str):
|
||||
def sample_portfolio(sample_portfolio_id: str) -> Portfolio:
|
||||
"""Return an unsaved Portfolio instance for testing."""
|
||||
# TODO: implement — construct a Portfolio dataclass with test values
|
||||
raise NotImplementedError
|
||||
return Portfolio(
|
||||
portfolio_id=sample_portfolio_id,
|
||||
name="Test Portfolio",
|
||||
cash=50_000.0,
|
||||
initial_cash=100_000.0,
|
||||
currency="USD",
|
||||
created_at="2026-03-20T00:00:00Z",
|
||||
updated_at="2026-03-20T00:00:00Z",
|
||||
report_path="reports/daily/2026-03-20/portfolio",
|
||||
metadata={"strategy": "test"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_holding(sample_portfolio_id: str, sample_holding_id: str):
|
||||
def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding:
|
||||
"""Return an unsaved Holding instance for testing."""
|
||||
# TODO: implement — construct a Holding dataclass with test values
|
||||
raise NotImplementedError
|
||||
return Holding(
|
||||
holding_id=sample_holding_id,
|
||||
portfolio_id=sample_portfolio_id,
|
||||
ticker="AAPL",
|
||||
shares=100.0,
|
||||
avg_cost=150.0,
|
||||
sector="Technology",
|
||||
industry="Consumer Electronics",
|
||||
created_at="2026-03-20T00:00:00Z",
|
||||
updated_at="2026-03-20T00:00:00Z",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_trade(sample_portfolio_id: str):
|
||||
def sample_trade(sample_portfolio_id: str) -> Trade:
|
||||
"""Return an unsaved Trade instance for testing."""
|
||||
# TODO: implement — construct a Trade dataclass with test values
|
||||
raise NotImplementedError
|
||||
return Trade(
|
||||
trade_id="33333333-3333-3333-3333-333333333333",
|
||||
portfolio_id=sample_portfolio_id,
|
||||
ticker="AAPL",
|
||||
action="BUY",
|
||||
shares=100.0,
|
||||
price=150.0,
|
||||
total_value=15_000.0,
|
||||
trade_date="2026-03-20T10:00:00Z",
|
||||
rationale="Strong momentum signal",
|
||||
signal_source="scanner",
|
||||
metadata={"confidence": 0.85},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_snapshot(sample_portfolio_id: str):
|
||||
def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot:
|
||||
"""Return an unsaved PortfolioSnapshot instance for testing."""
|
||||
# TODO: implement — construct a PortfolioSnapshot dataclass with test values
|
||||
raise NotImplementedError
|
||||
return PortfolioSnapshot(
|
||||
snapshot_id="44444444-4444-4444-4444-444444444444",
|
||||
portfolio_id=sample_portfolio_id,
|
||||
snapshot_date="2026-03-20",
|
||||
total_value=115_000.0,
|
||||
cash=50_000.0,
|
||||
equity_value=65_000.0,
|
||||
num_positions=2,
|
||||
holdings_snapshot=[
|
||||
{"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0},
|
||||
{"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0},
|
||||
],
|
||||
metadata={"note": "end of day snapshot"},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def report_store(tmp_reports: Path):
|
||||
def report_store(tmp_reports: Path) -> ReportStore:
|
||||
"""ReportStore instance backed by a temporary directory."""
|
||||
# TODO: implement — return ReportStore(base_dir=tmp_reports)
|
||||
raise NotImplementedError
|
||||
return ReportStore(base_dir=tmp_reports)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -104,7 +152,6 @@ def report_store(tmp_reports: Path):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_supabase_client():
|
||||
def mock_supabase_client() -> MagicMock:
|
||||
"""MagicMock of SupabaseClient for unit tests that don't hit the DB."""
|
||||
# TODO: implement — return MagicMock(spec=SupabaseClient)
|
||||
raise NotImplementedError
|
||||
return MagicMock(spec=SupabaseClient)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,13 @@ from __future__ import annotations
|
|||
|
||||
import pytest
|
||||
|
||||
from tradingagents.portfolio.models import (
|
||||
Holding,
|
||||
Portfolio,
|
||||
PortfolioSnapshot,
|
||||
Trade,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Portfolio round-trip
|
||||
|
|
@ -24,19 +31,41 @@ import pytest
|
|||
|
||||
def test_portfolio_to_dict_round_trip(sample_portfolio):
|
||||
"""Portfolio.to_dict() -> Portfolio.from_dict() must be lossless."""
|
||||
# TODO: implement
|
||||
# d = sample_portfolio.to_dict()
|
||||
# restored = Portfolio.from_dict(d)
|
||||
# assert restored.portfolio_id == sample_portfolio.portfolio_id
|
||||
# assert restored.cash == sample_portfolio.cash
|
||||
# ... all stored fields
|
||||
raise NotImplementedError
|
||||
d = sample_portfolio.to_dict()
|
||||
restored = Portfolio.from_dict(d)
|
||||
assert restored.portfolio_id == sample_portfolio.portfolio_id
|
||||
assert restored.name == sample_portfolio.name
|
||||
assert restored.cash == sample_portfolio.cash
|
||||
assert restored.initial_cash == sample_portfolio.initial_cash
|
||||
assert restored.currency == sample_portfolio.currency
|
||||
assert restored.created_at == sample_portfolio.created_at
|
||||
assert restored.updated_at == sample_portfolio.updated_at
|
||||
assert restored.report_path == sample_portfolio.report_path
|
||||
assert restored.metadata == sample_portfolio.metadata
|
||||
|
||||
|
||||
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
|
||||
"""to_dict() must not include computed fields (total_value, equity_value, cash_pct)."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
d = sample_portfolio.to_dict()
|
||||
assert "total_value" not in d
|
||||
assert "equity_value" not in d
|
||||
assert "cash_pct" not in d
|
||||
|
||||
|
||||
def test_portfolio_from_dict_defaults_optional_fields():
|
||||
"""from_dict() must tolerate missing optional fields."""
|
||||
minimal = {
|
||||
"portfolio_id": "pid-1",
|
||||
"name": "Minimal",
|
||||
"cash": 1000.0,
|
||||
"initial_cash": 1000.0,
|
||||
}
|
||||
p = Portfolio.from_dict(minimal)
|
||||
assert p.currency == "USD"
|
||||
assert p.created_at == ""
|
||||
assert p.updated_at == ""
|
||||
assert p.report_path is None
|
||||
assert p.metadata == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
|
|||
|
||||
def test_holding_to_dict_round_trip(sample_holding):
|
||||
"""Holding.to_dict() -> Holding.from_dict() must be lossless."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
d = sample_holding.to_dict()
|
||||
restored = Holding.from_dict(d)
|
||||
assert restored.holding_id == sample_holding.holding_id
|
||||
assert restored.portfolio_id == sample_holding.portfolio_id
|
||||
assert restored.ticker == sample_holding.ticker
|
||||
assert restored.shares == sample_holding.shares
|
||||
assert restored.avg_cost == sample_holding.avg_cost
|
||||
assert restored.sector == sample_holding.sector
|
||||
assert restored.industry == sample_holding.industry
|
||||
|
||||
|
||||
def test_holding_to_dict_excludes_runtime_fields(sample_holding):
|
||||
"""to_dict() must not include current_price, current_value, weight, etc."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
d = sample_holding.to_dict()
|
||||
for field in ("current_price", "current_value", "cost_basis",
|
||||
"unrealized_pnl", "unrealized_pnl_pct", "weight"):
|
||||
assert field not in d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding):
|
|||
|
||||
def test_trade_to_dict_round_trip(sample_trade):
|
||||
"""Trade.to_dict() -> Trade.from_dict() must be lossless."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
d = sample_trade.to_dict()
|
||||
restored = Trade.from_dict(d)
|
||||
assert restored.trade_id == sample_trade.trade_id
|
||||
assert restored.portfolio_id == sample_trade.portfolio_id
|
||||
assert restored.ticker == sample_trade.ticker
|
||||
assert restored.action == sample_trade.action
|
||||
assert restored.shares == sample_trade.shares
|
||||
assert restored.price == sample_trade.price
|
||||
assert restored.total_value == sample_trade.total_value
|
||||
assert restored.trade_date == sample_trade.trade_date
|
||||
assert restored.rationale == sample_trade.rationale
|
||||
assert restored.signal_source == sample_trade.signal_source
|
||||
assert restored.metadata == sample_trade.metadata
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade):
|
|||
|
||||
def test_snapshot_to_dict_round_trip(sample_snapshot):
|
||||
"""PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
d = sample_snapshot.to_dict()
|
||||
restored = PortfolioSnapshot.from_dict(d)
|
||||
assert restored.snapshot_id == sample_snapshot.snapshot_id
|
||||
assert restored.portfolio_id == sample_snapshot.portfolio_id
|
||||
assert restored.snapshot_date == sample_snapshot.snapshot_date
|
||||
assert restored.total_value == sample_snapshot.total_value
|
||||
assert restored.cash == sample_snapshot.cash
|
||||
assert restored.equity_value == sample_snapshot.equity_value
|
||||
assert restored.num_positions == sample_snapshot.num_positions
|
||||
assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot
|
||||
assert restored.metadata == sample_snapshot.metadata
|
||||
|
||||
|
||||
def test_snapshot_from_dict_parses_holdings_snapshot_json_string():
|
||||
"""from_dict() must parse holdings_snapshot when it arrives as a JSON string."""
|
||||
import json
|
||||
holdings = [{"ticker": "AAPL", "shares": 10.0}]
|
||||
data = {
|
||||
"snapshot_id": "snap-1",
|
||||
"portfolio_id": "pid-1",
|
||||
"snapshot_date": "2026-03-20",
|
||||
"total_value": 110_000.0,
|
||||
"cash": 10_000.0,
|
||||
"equity_value": 100_000.0,
|
||||
"num_positions": 1,
|
||||
"holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase
|
||||
}
|
||||
snap = PortfolioSnapshot.from_dict(data)
|
||||
assert snap.holdings_snapshot == holdings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot):
|
|||
|
||||
def test_holding_enrich_computes_current_value(sample_holding):
|
||||
"""enrich() must set current_value = current_price * shares."""
|
||||
# TODO: implement
|
||||
# sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
# assert sample_holding.current_value == 200.0 * sample_holding.shares
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
assert sample_holding.current_value == 200.0 * sample_holding.shares
|
||||
|
||||
|
||||
def test_holding_enrich_computes_unrealized_pnl(sample_holding):
|
||||
"""enrich() must set unrealized_pnl = current_value - cost_basis."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
expected_cost_basis = sample_holding.avg_cost * sample_holding.shares
|
||||
expected_pnl = sample_holding.current_value - expected_cost_basis
|
||||
assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl)
|
||||
|
||||
|
||||
def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding):
|
||||
"""enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis."""
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
cost_basis = sample_holding.avg_cost * sample_holding.shares
|
||||
expected_pct = sample_holding.unrealized_pnl / cost_basis
|
||||
assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct)
|
||||
|
||||
|
||||
def test_holding_enrich_computes_weight(sample_holding):
|
||||
"""enrich() must set weight = current_value / portfolio_total_value."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
expected_weight = sample_holding.current_value / 100_000.0
|
||||
assert sample_holding.weight == pytest.approx(expected_weight)
|
||||
|
||||
|
||||
def test_holding_enrich_returns_self(sample_holding):
|
||||
"""enrich() must return self for chaining."""
|
||||
result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
assert result is sample_holding
|
||||
|
||||
|
||||
def test_holding_enrich_handles_zero_cost(sample_holding):
|
||||
"""When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError)."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.avg_cost = 0.0
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||
assert sample_holding.unrealized_pnl_pct == 0.0
|
||||
|
||||
|
||||
def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
|
||||
"""When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError)."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0)
|
||||
assert sample_holding.weight == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
|
|||
|
||||
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
|
||||
"""Portfolio.enrich() must compute total_value = cash + sum(holding.current_value)."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||
sample_portfolio.enrich([sample_holding])
|
||||
expected_equity = 200.0 * sample_holding.shares
|
||||
assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity)
|
||||
|
||||
|
||||
def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding):
|
||||
"""Portfolio.enrich() must set equity_value = sum(holding.current_value)."""
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||
sample_portfolio.enrich([sample_holding])
|
||||
assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares)
|
||||
|
||||
|
||||
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
|
||||
"""Portfolio.enrich() must compute cash_pct = cash / total_value."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||
sample_portfolio.enrich([sample_holding])
|
||||
expected_pct = sample_portfolio.cash / sample_portfolio.total_value
|
||||
assert sample_portfolio.cash_pct == pytest.approx(expected_pct)
|
||||
|
||||
|
||||
def test_portfolio_enrich_returns_self(sample_portfolio):
|
||||
"""enrich() must return self for chaining."""
|
||||
result = sample_portfolio.enrich([])
|
||||
assert result is sample_portfolio
|
||||
|
||||
|
||||
def test_portfolio_enrich_no_holdings(sample_portfolio):
|
||||
"""Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash."""
|
||||
sample_portfolio.enrich([])
|
||||
assert sample_portfolio.equity_value == 0.0
|
||||
assert sample_portfolio.total_value == sample_portfolio.cash
|
||||
assert sample_portfolio.cash_pct == 1.0
|
||||
|
|
|
|||
|
|
@ -12,10 +12,14 @@ Run::
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tradingagents.portfolio.exceptions import ReportStoreError
|
||||
from tradingagents.portfolio.report_store import ReportStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro scan
|
||||
|
|
@ -24,19 +28,17 @@ import pytest
|
|||
|
||||
def test_save_and_load_scan(report_store, tmp_reports):
|
||||
"""save_scan() then load_scan() must return the original data."""
|
||||
# TODO: implement
|
||||
# data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
|
||||
# path = report_store.save_scan("2026-03-20", data)
|
||||
# assert path.exists()
|
||||
# loaded = report_store.load_scan("2026-03-20")
|
||||
# assert loaded == data
|
||||
raise NotImplementedError
|
||||
data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
|
||||
path = report_store.save_scan("2026-03-20", data)
|
||||
assert path.exists()
|
||||
loaded = report_store.load_scan("2026-03-20")
|
||||
assert loaded == data
|
||||
|
||||
|
||||
def test_load_scan_returns_none_for_missing_file(report_store):
|
||||
"""load_scan() must return None when the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
result = report_store.load_scan("1900-01-01")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store):
|
|||
|
||||
def test_save_and_load_analysis(report_store):
|
||||
"""save_analysis() then load_analysis() must return the original data."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92}
|
||||
report_store.save_analysis("2026-03-20", "AAPL", data)
|
||||
loaded = report_store.load_analysis("2026-03-20", "AAPL")
|
||||
assert loaded == data
|
||||
|
||||
|
||||
def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
|
||||
"""Ticker symbol must be stored as uppercase in the directory path."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
data = {"ticker": "aapl"}
|
||||
report_store.save_analysis("2026-03-20", "aapl", data)
|
||||
expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json"
|
||||
assert expected.exists()
|
||||
# load with lowercase should still work
|
||||
loaded = report_store.load_analysis("2026-03-20", "aapl")
|
||||
assert loaded == data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
|
|||
|
||||
def test_save_and_load_holding_review(report_store):
|
||||
"""save_holding_review() then load_holding_review() must round-trip."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0}
|
||||
report_store.save_holding_review("2026-03-20", "MSFT", data)
|
||||
loaded = report_store.load_holding_review("2026-03-20", "MSFT")
|
||||
assert loaded == data
|
||||
|
||||
|
||||
def test_load_holding_review_returns_none_for_missing(report_store):
|
||||
"""load_holding_review() must return None when the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
result = report_store.load_holding_review("1900-01-01", "ZZZZ")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store):
|
|||
|
||||
def test_save_and_load_risk_metrics(report_store):
|
||||
"""save_risk_metrics() then load_risk_metrics() must round-trip."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12}
|
||||
report_store.save_risk_metrics("2026-03-20", "pid-123", data)
|
||||
loaded = report_store.load_risk_metrics("2026-03-20", "pid-123")
|
||||
assert loaded == data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store):
|
|||
|
||||
def test_save_and_load_pm_decision_json(report_store):
|
||||
"""save_pm_decision() then load_pm_decision() must round-trip JSON."""
|
||||
# TODO: implement
|
||||
# decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
|
||||
# report_store.save_pm_decision("2026-03-20", "pid-123", decision)
|
||||
# loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
|
||||
# assert loaded == decision
|
||||
raise NotImplementedError
|
||||
decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
|
||||
report_store.save_pm_decision("2026-03-20", "pid-123", decision)
|
||||
loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
|
||||
assert loaded == decision
|
||||
|
||||
|
||||
def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports):
|
||||
"""When markdown is passed to save_pm_decision(), .md file must be written."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
decision = {"sells": [], "buys": []}
|
||||
md_text = "# Decision\n\nHold everything."
|
||||
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text)
|
||||
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
|
||||
assert md_path.exists()
|
||||
assert md_path.read_text(encoding="utf-8") == md_text
|
||||
|
||||
|
||||
def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports):
|
||||
"""When markdown=None, no .md file should be written."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
decision = {"sells": [], "buys": []}
|
||||
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None)
|
||||
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
|
||||
assert not md_path.exists()
|
||||
|
||||
|
||||
def test_load_pm_decision_returns_none_for_missing(report_store):
|
||||
"""load_pm_decision() must return None when the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
result = report_store.load_pm_decision("1900-01-01", "pid-none")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_list_pm_decisions(report_store):
|
||||
"""list_pm_decisions() must return all saved decision paths, newest first."""
|
||||
# TODO: implement
|
||||
# Save decisions for multiple dates, verify order
|
||||
raise NotImplementedError
|
||||
dates = ["2026-03-18", "2026-03-19", "2026-03-20"]
|
||||
for d in dates:
|
||||
report_store.save_pm_decision(d, "pid-abc", {"date": d})
|
||||
paths = report_store.list_pm_decisions("pid-abc")
|
||||
assert len(paths) == 3
|
||||
# Sorted newest first by ISO date string ordering
|
||||
date_parts = [p.parent.parent.name for p in paths]
|
||||
assert date_parts == sorted(dates, reverse=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store):
|
|||
|
||||
def test_directories_created_on_write(report_store, tmp_reports):
|
||||
"""Directories must be created automatically on first write."""
|
||||
# TODO: implement
|
||||
# assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists()
|
||||
# report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
|
||||
# assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir()
|
||||
raise NotImplementedError
|
||||
target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio"
|
||||
assert not target_dir.exists()
|
||||
report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
|
||||
assert target_dir.is_dir()
|
||||
|
||||
|
||||
def test_json_formatted_with_indent(report_store, tmp_reports):
|
||||
"""Written JSON files must use indent=2 for human readability."""
|
||||
# TODO: implement
|
||||
# Write a file, read the raw bytes, verify indentation
|
||||
raise NotImplementedError
|
||||
data = {"key": "value", "nested": {"a": 1}}
|
||||
path = report_store.save_scan("2026-03-20", data)
|
||||
raw = path.read_text(encoding="utf-8")
|
||||
# indent=2 means lines like ' "key": ...'
|
||||
assert ' "key"' in raw
|
||||
|
||||
|
||||
def test_read_json_raises_on_corrupt_file(report_store, tmp_reports):
|
||||
"""_read_json must raise ReportStoreError for corrupt JSON."""
|
||||
corrupt = tmp_reports / "corrupt.json"
|
||||
corrupt.write_text("not valid json{{{", encoding="utf-8")
|
||||
with pytest.raises(ReportStoreError):
|
||||
report_store._read_json(corrupt)
|
||||
|
|
|
|||
|
|
@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades (
|
|||
trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE,
|
||||
ticker TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')),
|
||||
shares NUMERIC(18,6) NOT NULL CHECK (shares > 0),
|
||||
price NUMERIC(18,4) NOT NULL CHECK (price > 0),
|
||||
total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0),
|
||||
trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
rationale TEXT, -- PM agent rationale for this trade
|
||||
signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent'
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL'))
|
||||
metadata JSONB NOT NULL DEFAULT '{}'
|
||||
);
|
||||
|
||||
COMMENT ON TABLE trades IS
|
||||
|
|
|
|||
|
|
@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with:
|
|||
- ``from_dict()`` class method for deserialisation
|
||||
- ``enrich()`` for attaching runtime-computed fields
|
||||
|
||||
**float vs Decimal** — monetary fields (cash, price, shares, etc.) use plain
|
||||
``float`` throughout. Rationale:
|
||||
|
||||
1. This is **mock trading only** — no real money changes hands. The cost of a
|
||||
subtle floating-point rounding error is zero.
|
||||
2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float``
|
||||
already. Converting to ``Decimal`` at the boundary would require a custom
|
||||
JSON encoder *and* decoder everywhere, for no practical gain.
|
||||
3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises
|
||||
``TypeError`` without a custom encoder.
|
||||
4. If this ever becomes real-money trading, replace ``float`` with
|
||||
``decimal.Decimal`` and add a ``DecimalEncoder`` — the interface is
|
||||
identical and the change is localised to this file.
|
||||
|
||||
See ``docs/portfolio/02_data_models.md`` for full field specifications.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -48,8 +63,17 @@ class Portfolio:
|
|||
|
||||
Runtime-computed fields are excluded.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return {
|
||||
"portfolio_id": self.portfolio_id,
|
||||
"name": self.name,
|
||||
"cash": self.cash,
|
||||
"initial_cash": self.initial_cash,
|
||||
"currency": self.currency,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"report_path": self.report_path,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "Portfolio":
|
||||
|
|
@ -57,8 +81,17 @@ class Portfolio:
|
|||
|
||||
Missing optional fields default gracefully. Extra keys are ignored.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return cls(
|
||||
portfolio_id=data["portfolio_id"],
|
||||
name=data["name"],
|
||||
cash=float(data["cash"]),
|
||||
initial_cash=float(data["initial_cash"]),
|
||||
currency=data.get("currency", "USD"),
|
||||
created_at=data.get("created_at", ""),
|
||||
updated_at=data.get("updated_at", ""),
|
||||
report_path=data.get("report_path"),
|
||||
metadata=data.get("metadata") or {},
|
||||
)
|
||||
|
||||
def enrich(self, holdings: list["Holding"]) -> "Portfolio":
|
||||
"""Compute total_value, equity_value, cash_pct from holdings.
|
||||
|
|
@ -69,8 +102,12 @@ class Portfolio:
|
|||
holdings: List of Holding objects with current_value populated
|
||||
(i.e., ``holding.enrich()`` already called).
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
self.equity_value = sum(
|
||||
h.current_value for h in holdings if h.current_value is not None
|
||||
)
|
||||
self.total_value = self.cash + self.equity_value
|
||||
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -106,14 +143,32 @@ class Holding:
|
|||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialise stored fields only (runtime-computed fields excluded)."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return {
|
||||
"holding_id": self.holding_id,
|
||||
"portfolio_id": self.portfolio_id,
|
||||
"ticker": self.ticker,
|
||||
"shares": self.shares,
|
||||
"avg_cost": self.avg_cost,
|
||||
"sector": self.sector,
|
||||
"industry": self.industry,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "Holding":
|
||||
"""Deserialise from a DB row or JSON dict."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return cls(
|
||||
holding_id=data["holding_id"],
|
||||
portfolio_id=data["portfolio_id"],
|
||||
ticker=data["ticker"],
|
||||
shares=float(data["shares"]),
|
||||
avg_cost=float(data["avg_cost"]),
|
||||
sector=data.get("sector"),
|
||||
industry=data.get("industry"),
|
||||
created_at=data.get("created_at", ""),
|
||||
updated_at=data.get("updated_at", ""),
|
||||
)
|
||||
|
||||
def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding":
|
||||
"""Populate runtime-computed fields in-place and return self.
|
||||
|
|
@ -129,8 +184,17 @@ class Holding:
|
|||
current_price: Latest market price for this ticker.
|
||||
portfolio_total_value: Total portfolio value (cash + equity).
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
self.current_price = current_price
|
||||
self.current_value = current_price * self.shares
|
||||
self.cost_basis = self.avg_cost * self.shares
|
||||
self.unrealized_pnl = self.current_value - self.cost_basis
|
||||
self.unrealized_pnl_pct = (
|
||||
self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0
|
||||
)
|
||||
self.weight = (
|
||||
self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -159,14 +223,36 @@ class Trade:
|
|||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialise all fields."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return {
|
||||
"trade_id": self.trade_id,
|
||||
"portfolio_id": self.portfolio_id,
|
||||
"ticker": self.ticker,
|
||||
"action": self.action,
|
||||
"shares": self.shares,
|
||||
"price": self.price,
|
||||
"total_value": self.total_value,
|
||||
"trade_date": self.trade_date,
|
||||
"rationale": self.rationale,
|
||||
"signal_source": self.signal_source,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "Trade":
|
||||
"""Deserialise from a DB row or JSON dict."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return cls(
|
||||
trade_id=data["trade_id"],
|
||||
portfolio_id=data["portfolio_id"],
|
||||
ticker=data["ticker"],
|
||||
action=data["action"],
|
||||
shares=float(data["shares"]),
|
||||
price=float(data["price"]),
|
||||
total_value=float(data["total_value"]),
|
||||
trade_date=data.get("trade_date", ""),
|
||||
rationale=data.get("rationale"),
|
||||
signal_source=data.get("signal_source"),
|
||||
metadata=data.get("metadata") or {},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -194,8 +280,17 @@ class PortfolioSnapshot:
|
|||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialise all fields. ``holdings_snapshot`` is already a list[dict]."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return {
|
||||
"snapshot_id": self.snapshot_id,
|
||||
"portfolio_id": self.portfolio_id,
|
||||
"snapshot_date": self.snapshot_date,
|
||||
"total_value": self.total_value,
|
||||
"cash": self.cash,
|
||||
"equity_value": self.equity_value,
|
||||
"num_positions": self.num_positions,
|
||||
"holdings_snapshot": self.holdings_snapshot,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot":
|
||||
|
|
@ -203,5 +298,17 @@ class PortfolioSnapshot:
|
|||
|
||||
``holdings_snapshot`` is parsed from a JSON string when needed.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
holdings_snapshot = data.get("holdings_snapshot", [])
|
||||
if isinstance(holdings_snapshot, str):
|
||||
holdings_snapshot = json.loads(holdings_snapshot)
|
||||
return cls(
|
||||
snapshot_id=data["snapshot_id"],
|
||||
portfolio_id=data["portfolio_id"],
|
||||
snapshot_date=data["snapshot_date"],
|
||||
total_value=float(data["total_value"]),
|
||||
cash=float(data["cash"]),
|
||||
equity_value=float(data["equity_value"]),
|
||||
num_positions=int(data["num_positions"]),
|
||||
holdings_snapshot=holdings_snapshot,
|
||||
metadata=data.get("metadata") or {},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,9 +28,12 @@ Usage::
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from tradingagents.portfolio.exceptions import ReportStoreError
|
||||
|
||||
|
||||
class ReportStore:
|
||||
"""Filesystem document store for all portfolio-related reports.
|
||||
|
|
@ -48,8 +51,7 @@ class ReportStore:
|
|||
Override via the ``PORTFOLIO_DATA_DIR`` env var or
|
||||
``get_portfolio_config()["data_dir"]``.
|
||||
"""
|
||||
# TODO: implement — store Path(base_dir), resolve as needed
|
||||
raise NotImplementedError
|
||||
self._base_dir = Path(base_dir)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
|
@ -60,8 +62,7 @@ class ReportStore:
|
|||
|
||||
Path: ``{base_dir}/daily/{date}/portfolio/``
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
return self._base_dir / "daily" / date / "portfolio"
|
||||
|
||||
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
|
||||
"""Write a dict to a JSON file, creating parent directories as needed.
|
||||
|
|
@ -76,8 +77,12 @@ class ReportStore:
|
|||
Raises:
|
||||
ReportStoreError: On filesystem write failure.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
return path
|
||||
except OSError as exc:
|
||||
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
|
||||
|
||||
def _read_json(self, path: Path) -> dict[str, Any] | None:
|
||||
"""Read a JSON file, returning None if the file does not exist.
|
||||
|
|
@ -85,8 +90,12 @@ class ReportStore:
|
|||
Raises:
|
||||
ReportStoreError: On JSON parse error (file exists but is corrupt).
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Macro Scan
|
||||
|
|
@ -104,13 +113,13 @@ class ReportStore:
|
|||
Returns:
|
||||
Path of the written file.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
|
||||
return self._write_json(path, data)
|
||||
|
||||
def load_scan(self, date: str) -> dict[str, Any] | None:
|
||||
"""Load macro scan summary. Returns None if the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
|
||||
return self._read_json(path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-Ticker Analysis
|
||||
|
|
@ -126,13 +135,13 @@ class ReportStore:
|
|||
ticker: Ticker symbol (stored as uppercase).
|
||||
data: Analysis output dict.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
|
||||
return self._write_json(path, data)
|
||||
|
||||
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
"""Load per-ticker analysis JSON. Returns None if the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
|
||||
return self._read_json(path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Holding Reviews
|
||||
|
|
@ -153,13 +162,13 @@ class ReportStore:
|
|||
ticker: Ticker symbol (stored as uppercase).
|
||||
data: HoldingReviewerAgent output dict.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
|
||||
return self._write_json(path, data)
|
||||
|
||||
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||
"""Load holding review output. Returns None if the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
|
||||
return self._read_json(path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Risk Metrics
|
||||
|
|
@ -180,8 +189,8 @@ class ReportStore:
|
|||
portfolio_id: UUID of the target portfolio.
|
||||
data: Risk metrics dict (Sharpe, Sortino, VaR, etc.).
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
|
||||
return self._write_json(path, data)
|
||||
|
||||
def load_risk_metrics(
|
||||
self,
|
||||
|
|
@ -189,8 +198,8 @@ class ReportStore:
|
|||
portfolio_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Load risk metrics. Returns None if the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
|
||||
return self._read_json(path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PM Decisions
|
||||
|
|
@ -218,8 +227,15 @@ class ReportStore:
|
|||
Returns:
|
||||
Path of the written JSON file.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
|
||||
self._write_json(json_path, data)
|
||||
if markdown is not None:
|
||||
md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md"
|
||||
try:
|
||||
md_path.write_text(markdown, encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc
|
||||
return json_path
|
||||
|
||||
def load_pm_decision(
|
||||
self,
|
||||
|
|
@ -227,8 +243,8 @@ class ReportStore:
|
|||
portfolio_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Load PM decision JSON. Returns None if the file does not exist."""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
|
||||
return self._read_json(path)
|
||||
|
||||
def list_pm_decisions(self, portfolio_id: str) -> list[Path]:
|
||||
"""Return all saved PM decision JSON paths for portfolio_id, newest first.
|
||||
|
|
@ -241,5 +257,5 @@ class ReportStore:
|
|||
Returns:
|
||||
Sorted list of Path objects, newest date first.
|
||||
"""
|
||||
# TODO: implement
|
||||
raise NotImplementedError
|
||||
pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json"
|
||||
return sorted(self._base_dir.glob(pattern), reverse=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue