feat: implement Portfolio models, ReportStore, and tests; fix SQL constraint
Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
parent
7ea9866d1d
commit
aa4dcdeb80
|
|
@ -17,12 +17,20 @@ Supabase integration tests use ``pytest.mark.skipif`` to auto-skip when
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.portfolio.models import (
|
||||||
|
Holding,
|
||||||
|
Portfolio,
|
||||||
|
PortfolioSnapshot,
|
||||||
|
Trade,
|
||||||
|
)
|
||||||
|
from tradingagents.portfolio.report_store import ReportStore
|
||||||
|
from tradingagents.portfolio.supabase_client import SupabaseClient
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Skip marker for Supabase integration tests
|
# Skip marker for Supabase integration tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -51,31 +59,72 @@ def sample_holding_id() -> str:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_portfolio(sample_portfolio_id: str):
|
def sample_portfolio(sample_portfolio_id: str) -> Portfolio:
|
||||||
"""Return an unsaved Portfolio instance for testing."""
|
"""Return an unsaved Portfolio instance for testing."""
|
||||||
# TODO: implement — construct a Portfolio dataclass with test values
|
return Portfolio(
|
||||||
raise NotImplementedError
|
portfolio_id=sample_portfolio_id,
|
||||||
|
name="Test Portfolio",
|
||||||
|
cash=50_000.0,
|
||||||
|
initial_cash=100_000.0,
|
||||||
|
currency="USD",
|
||||||
|
created_at="2026-03-20T00:00:00Z",
|
||||||
|
updated_at="2026-03-20T00:00:00Z",
|
||||||
|
report_path="reports/daily/2026-03-20/portfolio",
|
||||||
|
metadata={"strategy": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_holding(sample_portfolio_id: str, sample_holding_id: str):
|
def sample_holding(sample_portfolio_id: str, sample_holding_id: str) -> Holding:
|
||||||
"""Return an unsaved Holding instance for testing."""
|
"""Return an unsaved Holding instance for testing."""
|
||||||
# TODO: implement — construct a Holding dataclass with test values
|
return Holding(
|
||||||
raise NotImplementedError
|
holding_id=sample_holding_id,
|
||||||
|
portfolio_id=sample_portfolio_id,
|
||||||
|
ticker="AAPL",
|
||||||
|
shares=100.0,
|
||||||
|
avg_cost=150.0,
|
||||||
|
sector="Technology",
|
||||||
|
industry="Consumer Electronics",
|
||||||
|
created_at="2026-03-20T00:00:00Z",
|
||||||
|
updated_at="2026-03-20T00:00:00Z",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_trade(sample_portfolio_id: str):
|
def sample_trade(sample_portfolio_id: str) -> Trade:
|
||||||
"""Return an unsaved Trade instance for testing."""
|
"""Return an unsaved Trade instance for testing."""
|
||||||
# TODO: implement — construct a Trade dataclass with test values
|
return Trade(
|
||||||
raise NotImplementedError
|
trade_id="33333333-3333-3333-3333-333333333333",
|
||||||
|
portfolio_id=sample_portfolio_id,
|
||||||
|
ticker="AAPL",
|
||||||
|
action="BUY",
|
||||||
|
shares=100.0,
|
||||||
|
price=150.0,
|
||||||
|
total_value=15_000.0,
|
||||||
|
trade_date="2026-03-20T10:00:00Z",
|
||||||
|
rationale="Strong momentum signal",
|
||||||
|
signal_source="scanner",
|
||||||
|
metadata={"confidence": 0.85},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_snapshot(sample_portfolio_id: str):
|
def sample_snapshot(sample_portfolio_id: str) -> PortfolioSnapshot:
|
||||||
"""Return an unsaved PortfolioSnapshot instance for testing."""
|
"""Return an unsaved PortfolioSnapshot instance for testing."""
|
||||||
# TODO: implement — construct a PortfolioSnapshot dataclass with test values
|
return PortfolioSnapshot(
|
||||||
raise NotImplementedError
|
snapshot_id="44444444-4444-4444-4444-444444444444",
|
||||||
|
portfolio_id=sample_portfolio_id,
|
||||||
|
snapshot_date="2026-03-20",
|
||||||
|
total_value=115_000.0,
|
||||||
|
cash=50_000.0,
|
||||||
|
equity_value=65_000.0,
|
||||||
|
num_positions=2,
|
||||||
|
holdings_snapshot=[
|
||||||
|
{"ticker": "AAPL", "shares": 100.0, "avg_cost": 150.0},
|
||||||
|
{"ticker": "MSFT", "shares": 50.0, "avg_cost": 300.0},
|
||||||
|
],
|
||||||
|
metadata={"note": "end of day snapshot"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -92,10 +141,9 @@ def tmp_reports(tmp_path: Path) -> Path:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def report_store(tmp_reports: Path):
|
def report_store(tmp_reports: Path) -> ReportStore:
|
||||||
"""ReportStore instance backed by a temporary directory."""
|
"""ReportStore instance backed by a temporary directory."""
|
||||||
# TODO: implement — return ReportStore(base_dir=tmp_reports)
|
return ReportStore(base_dir=tmp_reports)
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -104,7 +152,6 @@ def report_store(tmp_reports: Path):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_supabase_client():
|
def mock_supabase_client() -> MagicMock:
|
||||||
"""MagicMock of SupabaseClient for unit tests that don't hit the DB."""
|
"""MagicMock of SupabaseClient for unit tests that don't hit the DB."""
|
||||||
# TODO: implement — return MagicMock(spec=SupabaseClient)
|
return MagicMock(spec=SupabaseClient)
|
||||||
raise NotImplementedError
|
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,13 @@ from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.portfolio.models import (
|
||||||
|
Holding,
|
||||||
|
Portfolio,
|
||||||
|
PortfolioSnapshot,
|
||||||
|
Trade,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Portfolio round-trip
|
# Portfolio round-trip
|
||||||
|
|
@ -24,19 +31,41 @@ import pytest
|
||||||
|
|
||||||
def test_portfolio_to_dict_round_trip(sample_portfolio):
|
def test_portfolio_to_dict_round_trip(sample_portfolio):
|
||||||
"""Portfolio.to_dict() -> Portfolio.from_dict() must be lossless."""
|
"""Portfolio.to_dict() -> Portfolio.from_dict() must be lossless."""
|
||||||
# TODO: implement
|
d = sample_portfolio.to_dict()
|
||||||
# d = sample_portfolio.to_dict()
|
restored = Portfolio.from_dict(d)
|
||||||
# restored = Portfolio.from_dict(d)
|
assert restored.portfolio_id == sample_portfolio.portfolio_id
|
||||||
# assert restored.portfolio_id == sample_portfolio.portfolio_id
|
assert restored.name == sample_portfolio.name
|
||||||
# assert restored.cash == sample_portfolio.cash
|
assert restored.cash == sample_portfolio.cash
|
||||||
# ... all stored fields
|
assert restored.initial_cash == sample_portfolio.initial_cash
|
||||||
raise NotImplementedError
|
assert restored.currency == sample_portfolio.currency
|
||||||
|
assert restored.created_at == sample_portfolio.created_at
|
||||||
|
assert restored.updated_at == sample_portfolio.updated_at
|
||||||
|
assert restored.report_path == sample_portfolio.report_path
|
||||||
|
assert restored.metadata == sample_portfolio.metadata
|
||||||
|
|
||||||
|
|
||||||
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
|
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
|
||||||
"""to_dict() must not include computed fields (total_value, equity_value, cash_pct)."""
|
"""to_dict() must not include computed fields (total_value, equity_value, cash_pct)."""
|
||||||
# TODO: implement
|
d = sample_portfolio.to_dict()
|
||||||
raise NotImplementedError
|
assert "total_value" not in d
|
||||||
|
assert "equity_value" not in d
|
||||||
|
assert "cash_pct" not in d
|
||||||
|
|
||||||
|
|
||||||
|
def test_portfolio_from_dict_defaults_optional_fields():
|
||||||
|
"""from_dict() must tolerate missing optional fields."""
|
||||||
|
minimal = {
|
||||||
|
"portfolio_id": "pid-1",
|
||||||
|
"name": "Minimal",
|
||||||
|
"cash": 1000.0,
|
||||||
|
"initial_cash": 1000.0,
|
||||||
|
}
|
||||||
|
p = Portfolio.from_dict(minimal)
|
||||||
|
assert p.currency == "USD"
|
||||||
|
assert p.created_at == ""
|
||||||
|
assert p.updated_at == ""
|
||||||
|
assert p.report_path is None
|
||||||
|
assert p.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -46,14 +75,23 @@ def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
|
||||||
|
|
||||||
def test_holding_to_dict_round_trip(sample_holding):
|
def test_holding_to_dict_round_trip(sample_holding):
|
||||||
"""Holding.to_dict() -> Holding.from_dict() must be lossless."""
|
"""Holding.to_dict() -> Holding.from_dict() must be lossless."""
|
||||||
# TODO: implement
|
d = sample_holding.to_dict()
|
||||||
raise NotImplementedError
|
restored = Holding.from_dict(d)
|
||||||
|
assert restored.holding_id == sample_holding.holding_id
|
||||||
|
assert restored.portfolio_id == sample_holding.portfolio_id
|
||||||
|
assert restored.ticker == sample_holding.ticker
|
||||||
|
assert restored.shares == sample_holding.shares
|
||||||
|
assert restored.avg_cost == sample_holding.avg_cost
|
||||||
|
assert restored.sector == sample_holding.sector
|
||||||
|
assert restored.industry == sample_holding.industry
|
||||||
|
|
||||||
|
|
||||||
def test_holding_to_dict_excludes_runtime_fields(sample_holding):
|
def test_holding_to_dict_excludes_runtime_fields(sample_holding):
|
||||||
"""to_dict() must not include current_price, current_value, weight, etc."""
|
"""to_dict() must not include current_price, current_value, weight, etc."""
|
||||||
# TODO: implement
|
d = sample_holding.to_dict()
|
||||||
raise NotImplementedError
|
for field in ("current_price", "current_value", "cost_basis",
|
||||||
|
"unrealized_pnl", "unrealized_pnl_pct", "weight"):
|
||||||
|
assert field not in d
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -63,8 +101,19 @@ def test_holding_to_dict_excludes_runtime_fields(sample_holding):
|
||||||
|
|
||||||
def test_trade_to_dict_round_trip(sample_trade):
|
def test_trade_to_dict_round_trip(sample_trade):
|
||||||
"""Trade.to_dict() -> Trade.from_dict() must be lossless."""
|
"""Trade.to_dict() -> Trade.from_dict() must be lossless."""
|
||||||
# TODO: implement
|
d = sample_trade.to_dict()
|
||||||
raise NotImplementedError
|
restored = Trade.from_dict(d)
|
||||||
|
assert restored.trade_id == sample_trade.trade_id
|
||||||
|
assert restored.portfolio_id == sample_trade.portfolio_id
|
||||||
|
assert restored.ticker == sample_trade.ticker
|
||||||
|
assert restored.action == sample_trade.action
|
||||||
|
assert restored.shares == sample_trade.shares
|
||||||
|
assert restored.price == sample_trade.price
|
||||||
|
assert restored.total_value == sample_trade.total_value
|
||||||
|
assert restored.trade_date == sample_trade.trade_date
|
||||||
|
assert restored.rationale == sample_trade.rationale
|
||||||
|
assert restored.signal_source == sample_trade.signal_source
|
||||||
|
assert restored.metadata == sample_trade.metadata
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -74,8 +123,35 @@ def test_trade_to_dict_round_trip(sample_trade):
|
||||||
|
|
||||||
def test_snapshot_to_dict_round_trip(sample_snapshot):
|
def test_snapshot_to_dict_round_trip(sample_snapshot):
|
||||||
"""PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip."""
|
"""PortfolioSnapshot.to_dict() -> PortfolioSnapshot.from_dict() round-trip."""
|
||||||
# TODO: implement
|
d = sample_snapshot.to_dict()
|
||||||
raise NotImplementedError
|
restored = PortfolioSnapshot.from_dict(d)
|
||||||
|
assert restored.snapshot_id == sample_snapshot.snapshot_id
|
||||||
|
assert restored.portfolio_id == sample_snapshot.portfolio_id
|
||||||
|
assert restored.snapshot_date == sample_snapshot.snapshot_date
|
||||||
|
assert restored.total_value == sample_snapshot.total_value
|
||||||
|
assert restored.cash == sample_snapshot.cash
|
||||||
|
assert restored.equity_value == sample_snapshot.equity_value
|
||||||
|
assert restored.num_positions == sample_snapshot.num_positions
|
||||||
|
assert restored.holdings_snapshot == sample_snapshot.holdings_snapshot
|
||||||
|
assert restored.metadata == sample_snapshot.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def test_snapshot_from_dict_parses_holdings_snapshot_json_string():
|
||||||
|
"""from_dict() must parse holdings_snapshot when it arrives as a JSON string."""
|
||||||
|
import json
|
||||||
|
holdings = [{"ticker": "AAPL", "shares": 10.0}]
|
||||||
|
data = {
|
||||||
|
"snapshot_id": "snap-1",
|
||||||
|
"portfolio_id": "pid-1",
|
||||||
|
"snapshot_date": "2026-03-20",
|
||||||
|
"total_value": 110_000.0,
|
||||||
|
"cash": 10_000.0,
|
||||||
|
"equity_value": 100_000.0,
|
||||||
|
"num_positions": 1,
|
||||||
|
"holdings_snapshot": json.dumps(holdings), # string form as returned by Supabase
|
||||||
|
}
|
||||||
|
snap = PortfolioSnapshot.from_dict(data)
|
||||||
|
assert snap.holdings_snapshot == holdings
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -85,34 +161,50 @@ def test_snapshot_to_dict_round_trip(sample_snapshot):
|
||||||
|
|
||||||
def test_holding_enrich_computes_current_value(sample_holding):
|
def test_holding_enrich_computes_current_value(sample_holding):
|
||||||
"""enrich() must set current_value = current_price * shares."""
|
"""enrich() must set current_value = current_price * shares."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
# sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
assert sample_holding.current_value == 200.0 * sample_holding.shares
|
||||||
# assert sample_holding.current_value == 200.0 * sample_holding.shares
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def test_holding_enrich_computes_unrealized_pnl(sample_holding):
|
def test_holding_enrich_computes_unrealized_pnl(sample_holding):
|
||||||
"""enrich() must set unrealized_pnl = current_value - cost_basis."""
|
"""enrich() must set unrealized_pnl = current_value - cost_basis."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
raise NotImplementedError
|
expected_cost_basis = sample_holding.avg_cost * sample_holding.shares
|
||||||
|
expected_pnl = sample_holding.current_value - expected_cost_basis
|
||||||
|
assert sample_holding.unrealized_pnl == pytest.approx(expected_pnl)
|
||||||
|
|
||||||
|
|
||||||
|
def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding):
|
||||||
|
"""enrich() must set unrealized_pnl_pct = unrealized_pnl / cost_basis."""
|
||||||
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
|
cost_basis = sample_holding.avg_cost * sample_holding.shares
|
||||||
|
expected_pct = sample_holding.unrealized_pnl / cost_basis
|
||||||
|
assert sample_holding.unrealized_pnl_pct == pytest.approx(expected_pct)
|
||||||
|
|
||||||
|
|
||||||
def test_holding_enrich_computes_weight(sample_holding):
|
def test_holding_enrich_computes_weight(sample_holding):
|
||||||
"""enrich() must set weight = current_value / portfolio_total_value."""
|
"""enrich() must set weight = current_value / portfolio_total_value."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
raise NotImplementedError
|
expected_weight = sample_holding.current_value / 100_000.0
|
||||||
|
assert sample_holding.weight == pytest.approx(expected_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def test_holding_enrich_returns_self(sample_holding):
|
||||||
|
"""enrich() must return self for chaining."""
|
||||||
|
result = sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
|
assert result is sample_holding
|
||||||
|
|
||||||
|
|
||||||
def test_holding_enrich_handles_zero_cost(sample_holding):
|
def test_holding_enrich_handles_zero_cost(sample_holding):
|
||||||
"""When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError)."""
|
"""When avg_cost == 0, unrealized_pnl_pct must be 0 (no ZeroDivisionError)."""
|
||||||
# TODO: implement
|
sample_holding.avg_cost = 0.0
|
||||||
raise NotImplementedError
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)
|
||||||
|
assert sample_holding.unrealized_pnl_pct == 0.0
|
||||||
|
|
||||||
|
|
||||||
def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
|
def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
|
||||||
"""When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError)."""
|
"""When portfolio_total_value == 0, weight must be 0 (no ZeroDivisionError)."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=0.0)
|
||||||
raise NotImplementedError
|
assert sample_holding.weight == 0.0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -122,11 +214,36 @@ def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
|
||||||
|
|
||||||
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
|
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
|
||||||
"""Portfolio.enrich() must compute total_value = cash + sum(holding.current_value)."""
|
"""Portfolio.enrich() must compute total_value = cash + sum(holding.current_value)."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||||
raise NotImplementedError
|
sample_portfolio.enrich([sample_holding])
|
||||||
|
expected_equity = 200.0 * sample_holding.shares
|
||||||
|
assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + expected_equity)
|
||||||
|
|
||||||
|
|
||||||
|
def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding):
|
||||||
|
"""Portfolio.enrich() must set equity_value = sum(holding.current_value)."""
|
||||||
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||||
|
sample_portfolio.enrich([sample_holding])
|
||||||
|
assert sample_portfolio.equity_value == pytest.approx(200.0 * sample_holding.shares)
|
||||||
|
|
||||||
|
|
||||||
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
|
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
|
||||||
"""Portfolio.enrich() must compute cash_pct = cash / total_value."""
|
"""Portfolio.enrich() must compute cash_pct = cash / total_value."""
|
||||||
# TODO: implement
|
sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0) # sets current_value; dummy total is overwritten by portfolio.enrich()
|
||||||
raise NotImplementedError
|
sample_portfolio.enrich([sample_holding])
|
||||||
|
expected_pct = sample_portfolio.cash / sample_portfolio.total_value
|
||||||
|
assert sample_portfolio.cash_pct == pytest.approx(expected_pct)
|
||||||
|
|
||||||
|
|
||||||
|
def test_portfolio_enrich_returns_self(sample_portfolio):
|
||||||
|
"""enrich() must return self for chaining."""
|
||||||
|
result = sample_portfolio.enrich([])
|
||||||
|
assert result is sample_portfolio
|
||||||
|
|
||||||
|
|
||||||
|
def test_portfolio_enrich_no_holdings(sample_portfolio):
|
||||||
|
"""Portfolio.enrich() with empty holdings: equity_value=0, total_value=cash."""
|
||||||
|
sample_portfolio.enrich([])
|
||||||
|
assert sample_portfolio.equity_value == 0.0
|
||||||
|
assert sample_portfolio.total_value == sample_portfolio.cash
|
||||||
|
assert sample_portfolio.cash_pct == 1.0
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,14 @@ Run::
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tradingagents.portfolio.exceptions import ReportStoreError
|
||||||
|
from tradingagents.portfolio.report_store import ReportStore
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Macro scan
|
# Macro scan
|
||||||
|
|
@ -24,19 +28,17 @@ import pytest
|
||||||
|
|
||||||
def test_save_and_load_scan(report_store, tmp_reports):
|
def test_save_and_load_scan(report_store, tmp_reports):
|
||||||
"""save_scan() then load_scan() must return the original data."""
|
"""save_scan() then load_scan() must return the original data."""
|
||||||
# TODO: implement
|
data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
|
||||||
# data = {"watchlist": ["AAPL", "MSFT"], "date": "2026-03-20"}
|
path = report_store.save_scan("2026-03-20", data)
|
||||||
# path = report_store.save_scan("2026-03-20", data)
|
assert path.exists()
|
||||||
# assert path.exists()
|
loaded = report_store.load_scan("2026-03-20")
|
||||||
# loaded = report_store.load_scan("2026-03-20")
|
assert loaded == data
|
||||||
# assert loaded == data
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_scan_returns_none_for_missing_file(report_store):
|
def test_load_scan_returns_none_for_missing_file(report_store):
|
||||||
"""load_scan() must return None when the file does not exist."""
|
"""load_scan() must return None when the file does not exist."""
|
||||||
# TODO: implement
|
result = report_store.load_scan("1900-01-01")
|
||||||
raise NotImplementedError
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -46,14 +48,21 @@ def test_load_scan_returns_none_for_missing_file(report_store):
|
||||||
|
|
||||||
def test_save_and_load_analysis(report_store):
|
def test_save_and_load_analysis(report_store):
|
||||||
"""save_analysis() then load_analysis() must return the original data."""
|
"""save_analysis() then load_analysis() must return the original data."""
|
||||||
# TODO: implement
|
data = {"ticker": "AAPL", "recommendation": "BUY", "score": 0.92}
|
||||||
raise NotImplementedError
|
report_store.save_analysis("2026-03-20", "AAPL", data)
|
||||||
|
loaded = report_store.load_analysis("2026-03-20", "AAPL")
|
||||||
|
assert loaded == data
|
||||||
|
|
||||||
|
|
||||||
def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
|
def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
|
||||||
"""Ticker symbol must be stored as uppercase in the directory path."""
|
"""Ticker symbol must be stored as uppercase in the directory path."""
|
||||||
# TODO: implement
|
data = {"ticker": "aapl"}
|
||||||
raise NotImplementedError
|
report_store.save_analysis("2026-03-20", "aapl", data)
|
||||||
|
expected = tmp_reports / "daily" / "2026-03-20" / "AAPL" / "complete_report.json"
|
||||||
|
assert expected.exists()
|
||||||
|
# load with lowercase should still work
|
||||||
|
loaded = report_store.load_analysis("2026-03-20", "aapl")
|
||||||
|
assert loaded == data
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -63,14 +72,16 @@ def test_analysis_ticker_stored_as_uppercase(report_store, tmp_reports):
|
||||||
|
|
||||||
def test_save_and_load_holding_review(report_store):
|
def test_save_and_load_holding_review(report_store):
|
||||||
"""save_holding_review() then load_holding_review() must round-trip."""
|
"""save_holding_review() then load_holding_review() must round-trip."""
|
||||||
# TODO: implement
|
data = {"ticker": "MSFT", "verdict": "HOLD", "price_target": 420.0}
|
||||||
raise NotImplementedError
|
report_store.save_holding_review("2026-03-20", "MSFT", data)
|
||||||
|
loaded = report_store.load_holding_review("2026-03-20", "MSFT")
|
||||||
|
assert loaded == data
|
||||||
|
|
||||||
|
|
||||||
def test_load_holding_review_returns_none_for_missing(report_store):
|
def test_load_holding_review_returns_none_for_missing(report_store):
|
||||||
"""load_holding_review() must return None when the file does not exist."""
|
"""load_holding_review() must return None when the file does not exist."""
|
||||||
# TODO: implement
|
result = report_store.load_holding_review("1900-01-01", "ZZZZ")
|
||||||
raise NotImplementedError
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -80,8 +91,10 @@ def test_load_holding_review_returns_none_for_missing(report_store):
|
||||||
|
|
||||||
def test_save_and_load_risk_metrics(report_store):
|
def test_save_and_load_risk_metrics(report_store):
|
||||||
"""save_risk_metrics() then load_risk_metrics() must round-trip."""
|
"""save_risk_metrics() then load_risk_metrics() must round-trip."""
|
||||||
# TODO: implement
|
data = {"sharpe": 1.35, "sortino": 1.8, "max_drawdown": -0.12}
|
||||||
raise NotImplementedError
|
report_store.save_risk_metrics("2026-03-20", "pid-123", data)
|
||||||
|
loaded = report_store.load_risk_metrics("2026-03-20", "pid-123")
|
||||||
|
assert loaded == data
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -91,37 +104,46 @@ def test_save_and_load_risk_metrics(report_store):
|
||||||
|
|
||||||
def test_save_and_load_pm_decision_json(report_store):
|
def test_save_and_load_pm_decision_json(report_store):
|
||||||
"""save_pm_decision() then load_pm_decision() must round-trip JSON."""
|
"""save_pm_decision() then load_pm_decision() must round-trip JSON."""
|
||||||
# TODO: implement
|
decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
|
||||||
# decision = {"sells": [], "buys": [{"ticker": "AAPL", "shares": 10}]}
|
report_store.save_pm_decision("2026-03-20", "pid-123", decision)
|
||||||
# report_store.save_pm_decision("2026-03-20", "pid-123", decision)
|
loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
|
||||||
# loaded = report_store.load_pm_decision("2026-03-20", "pid-123")
|
assert loaded == decision
|
||||||
# assert loaded == decision
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports):
|
def test_save_pm_decision_writes_markdown_when_provided(report_store, tmp_reports):
|
||||||
"""When markdown is passed to save_pm_decision(), .md file must be written."""
|
"""When markdown is passed to save_pm_decision(), .md file must be written."""
|
||||||
# TODO: implement
|
decision = {"sells": [], "buys": []}
|
||||||
raise NotImplementedError
|
md_text = "# Decision\n\nHold everything."
|
||||||
|
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=md_text)
|
||||||
|
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
|
||||||
|
assert md_path.exists()
|
||||||
|
assert md_path.read_text(encoding="utf-8") == md_text
|
||||||
|
|
||||||
|
|
||||||
def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports):
|
def test_save_pm_decision_no_markdown_file_when_not_provided(report_store, tmp_reports):
|
||||||
"""When markdown=None, no .md file should be written."""
|
"""When markdown=None, no .md file should be written."""
|
||||||
# TODO: implement
|
decision = {"sells": [], "buys": []}
|
||||||
raise NotImplementedError
|
report_store.save_pm_decision("2026-03-20", "pid-123", decision, markdown=None)
|
||||||
|
md_path = tmp_reports / "daily" / "2026-03-20" / "portfolio" / "pid-123_pm_decision.md"
|
||||||
|
assert not md_path.exists()
|
||||||
|
|
||||||
|
|
||||||
def test_load_pm_decision_returns_none_for_missing(report_store):
|
def test_load_pm_decision_returns_none_for_missing(report_store):
|
||||||
"""load_pm_decision() must return None when the file does not exist."""
|
"""load_pm_decision() must return None when the file does not exist."""
|
||||||
# TODO: implement
|
result = report_store.load_pm_decision("1900-01-01", "pid-none")
|
||||||
raise NotImplementedError
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
def test_list_pm_decisions(report_store):
|
def test_list_pm_decisions(report_store):
|
||||||
"""list_pm_decisions() must return all saved decision paths, newest first."""
|
"""list_pm_decisions() must return all saved decision paths, newest first."""
|
||||||
# TODO: implement
|
dates = ["2026-03-18", "2026-03-19", "2026-03-20"]
|
||||||
# Save decisions for multiple dates, verify order
|
for d in dates:
|
||||||
raise NotImplementedError
|
report_store.save_pm_decision(d, "pid-abc", {"date": d})
|
||||||
|
paths = report_store.list_pm_decisions("pid-abc")
|
||||||
|
assert len(paths) == 3
|
||||||
|
# Sorted newest first by ISO date string ordering
|
||||||
|
date_parts = [p.parent.parent.name for p in paths]
|
||||||
|
assert date_parts == sorted(dates, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -131,15 +153,24 @@ def test_list_pm_decisions(report_store):
|
||||||
|
|
||||||
def test_directories_created_on_write(report_store, tmp_reports):
|
def test_directories_created_on_write(report_store, tmp_reports):
|
||||||
"""Directories must be created automatically on first write."""
|
"""Directories must be created automatically on first write."""
|
||||||
# TODO: implement
|
target_dir = tmp_reports / "daily" / "2026-03-20" / "portfolio"
|
||||||
# assert not (tmp_reports / "daily" / "2026-03-20" / "portfolio").exists()
|
assert not target_dir.exists()
|
||||||
# report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
|
report_store.save_risk_metrics("2026-03-20", "pid-123", {"sharpe": 1.2})
|
||||||
# assert (tmp_reports / "daily" / "2026-03-20" / "portfolio").is_dir()
|
assert target_dir.is_dir()
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_formatted_with_indent(report_store, tmp_reports):
|
def test_json_formatted_with_indent(report_store, tmp_reports):
|
||||||
"""Written JSON files must use indent=2 for human readability."""
|
"""Written JSON files must use indent=2 for human readability."""
|
||||||
# TODO: implement
|
data = {"key": "value", "nested": {"a": 1}}
|
||||||
# Write a file, read the raw bytes, verify indentation
|
path = report_store.save_scan("2026-03-20", data)
|
||||||
raise NotImplementedError
|
raw = path.read_text(encoding="utf-8")
|
||||||
|
# indent=2 means lines like ' "key": ...'
|
||||||
|
assert ' "key"' in raw
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_json_raises_on_corrupt_file(report_store, tmp_reports):
|
||||||
|
"""_read_json must raise ReportStoreError for corrupt JSON."""
|
||||||
|
corrupt = tmp_reports / "corrupt.json"
|
||||||
|
corrupt.write_text("not valid json{{{", encoding="utf-8")
|
||||||
|
with pytest.raises(ReportStoreError):
|
||||||
|
report_store._read_json(corrupt)
|
||||||
|
|
|
||||||
|
|
@ -65,16 +65,14 @@ CREATE TABLE IF NOT EXISTS trades (
|
||||||
trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
trade_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE,
|
portfolio_id UUID NOT NULL REFERENCES portfolios(portfolio_id) ON DELETE CASCADE,
|
||||||
ticker TEXT NOT NULL,
|
ticker TEXT NOT NULL,
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL CHECK (action IN ('BUY', 'SELL')),
|
||||||
shares NUMERIC(18,6) NOT NULL CHECK (shares > 0),
|
shares NUMERIC(18,6) NOT NULL CHECK (shares > 0),
|
||||||
price NUMERIC(18,4) NOT NULL CHECK (price > 0),
|
price NUMERIC(18,4) NOT NULL CHECK (price > 0),
|
||||||
total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0),
|
total_value NUMERIC(18,4) NOT NULL CHECK (total_value > 0),
|
||||||
trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
trade_date TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||||
rationale TEXT, -- PM agent rationale for this trade
|
rationale TEXT, -- PM agent rationale for this trade
|
||||||
signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent'
|
signal_source TEXT, -- 'scanner' | 'holding_review' | 'pm_agent'
|
||||||
metadata JSONB NOT NULL DEFAULT '{}',
|
metadata JSONB NOT NULL DEFAULT '{}'
|
||||||
|
|
||||||
CONSTRAINT trades_action_values CHECK (action IN ('BUY', 'SELL'))
|
|
||||||
);
|
);
|
||||||
|
|
||||||
COMMENT ON TABLE trades IS
|
COMMENT ON TABLE trades IS
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,26 @@ All models are Python ``dataclass`` types with:
|
||||||
- ``from_dict()`` class method for deserialisation
|
- ``from_dict()`` class method for deserialisation
|
||||||
- ``enrich()`` for attaching runtime-computed fields
|
- ``enrich()`` for attaching runtime-computed fields
|
||||||
|
|
||||||
|
**float vs Decimal** — monetary fields (cash, price, shares, etc.) use plain
|
||||||
|
``float`` throughout. Rationale:
|
||||||
|
|
||||||
|
1. This is **mock trading only** — no real money changes hands. The cost of a
|
||||||
|
subtle floating-point rounding error is zero.
|
||||||
|
2. All upstream data sources (yfinance, Alpha Vantage, Finnhub) return ``float``
|
||||||
|
already. Converting to ``Decimal`` at the boundary would require a custom
|
||||||
|
JSON encoder *and* decoder everywhere, for no practical gain.
|
||||||
|
3. ``json.dumps`` serialises ``float`` natively; ``Decimal`` raises
|
||||||
|
``TypeError`` without a custom encoder.
|
||||||
|
4. If this ever becomes real-money trading, replace ``float`` with
|
||||||
|
``decimal.Decimal`` and add a ``DecimalEncoder`` — the interface is
|
||||||
|
identical and the change is localised to this file.
|
||||||
|
|
||||||
See ``docs/portfolio/02_data_models.md`` for full field specifications.
|
See ``docs/portfolio/02_data_models.md`` for full field specifications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -48,8 +63,17 @@ class Portfolio:
|
||||||
|
|
||||||
Runtime-computed fields are excluded.
|
Runtime-computed fields are excluded.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
return {
|
||||||
raise NotImplementedError
|
"portfolio_id": self.portfolio_id,
|
||||||
|
"name": self.name,
|
||||||
|
"cash": self.cash,
|
||||||
|
"initial_cash": self.initial_cash,
|
||||||
|
"currency": self.currency,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"updated_at": self.updated_at,
|
||||||
|
"report_path": self.report_path,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "Portfolio":
|
def from_dict(cls, data: dict[str, Any]) -> "Portfolio":
|
||||||
|
|
@ -57,8 +81,17 @@ class Portfolio:
|
||||||
|
|
||||||
Missing optional fields default gracefully. Extra keys are ignored.
|
Missing optional fields default gracefully. Extra keys are ignored.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
return cls(
|
||||||
raise NotImplementedError
|
portfolio_id=data["portfolio_id"],
|
||||||
|
name=data["name"],
|
||||||
|
cash=float(data["cash"]),
|
||||||
|
initial_cash=float(data["initial_cash"]),
|
||||||
|
currency=data.get("currency", "USD"),
|
||||||
|
created_at=data.get("created_at", ""),
|
||||||
|
updated_at=data.get("updated_at", ""),
|
||||||
|
report_path=data.get("report_path"),
|
||||||
|
metadata=data.get("metadata") or {},
|
||||||
|
)
|
||||||
|
|
||||||
def enrich(self, holdings: list["Holding"]) -> "Portfolio":
|
def enrich(self, holdings: list["Holding"]) -> "Portfolio":
|
||||||
"""Compute total_value, equity_value, cash_pct from holdings.
|
"""Compute total_value, equity_value, cash_pct from holdings.
|
||||||
|
|
@ -69,8 +102,12 @@ class Portfolio:
|
||||||
holdings: List of Holding objects with current_value populated
|
holdings: List of Holding objects with current_value populated
|
||||||
(i.e., ``holding.enrich()`` already called).
|
(i.e., ``holding.enrich()`` already called).
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
self.equity_value = sum(
|
||||||
raise NotImplementedError
|
h.current_value for h in holdings if h.current_value is not None
|
||||||
|
)
|
||||||
|
self.total_value = self.cash + self.equity_value
|
||||||
|
self.cash_pct = self.cash / self.total_value if self.total_value != 0.0 else 0.0
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -106,14 +143,32 @@ class Holding:
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Serialise stored fields only (runtime-computed fields excluded)."""
|
"""Serialise stored fields only (runtime-computed fields excluded)."""
|
||||||
# TODO: implement
|
return {
|
||||||
raise NotImplementedError
|
"holding_id": self.holding_id,
|
||||||
|
"portfolio_id": self.portfolio_id,
|
||||||
|
"ticker": self.ticker,
|
||||||
|
"shares": self.shares,
|
||||||
|
"avg_cost": self.avg_cost,
|
||||||
|
"sector": self.sector,
|
||||||
|
"industry": self.industry,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"updated_at": self.updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "Holding":
|
def from_dict(cls, data: dict[str, Any]) -> "Holding":
|
||||||
"""Deserialise from a DB row or JSON dict."""
|
"""Deserialise from a DB row or JSON dict."""
|
||||||
# TODO: implement
|
return cls(
|
||||||
raise NotImplementedError
|
holding_id=data["holding_id"],
|
||||||
|
portfolio_id=data["portfolio_id"],
|
||||||
|
ticker=data["ticker"],
|
||||||
|
shares=float(data["shares"]),
|
||||||
|
avg_cost=float(data["avg_cost"]),
|
||||||
|
sector=data.get("sector"),
|
||||||
|
industry=data.get("industry"),
|
||||||
|
created_at=data.get("created_at", ""),
|
||||||
|
updated_at=data.get("updated_at", ""),
|
||||||
|
)
|
||||||
|
|
||||||
def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding":
|
def enrich(self, current_price: float, portfolio_total_value: float) -> "Holding":
|
||||||
"""Populate runtime-computed fields in-place and return self.
|
"""Populate runtime-computed fields in-place and return self.
|
||||||
|
|
@ -129,8 +184,17 @@ class Holding:
|
||||||
current_price: Latest market price for this ticker.
|
current_price: Latest market price for this ticker.
|
||||||
portfolio_total_value: Total portfolio value (cash + equity).
|
portfolio_total_value: Total portfolio value (cash + equity).
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
self.current_price = current_price
|
||||||
raise NotImplementedError
|
self.current_value = current_price * self.shares
|
||||||
|
self.cost_basis = self.avg_cost * self.shares
|
||||||
|
self.unrealized_pnl = self.current_value - self.cost_basis
|
||||||
|
self.unrealized_pnl_pct = (
|
||||||
|
self.unrealized_pnl / self.cost_basis if self.cost_basis != 0.0 else 0.0
|
||||||
|
)
|
||||||
|
self.weight = (
|
||||||
|
self.current_value / portfolio_total_value if portfolio_total_value != 0.0 else 0.0
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -159,14 +223,36 @@ class Trade:
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Serialise all fields."""
|
"""Serialise all fields."""
|
||||||
# TODO: implement
|
return {
|
||||||
raise NotImplementedError
|
"trade_id": self.trade_id,
|
||||||
|
"portfolio_id": self.portfolio_id,
|
||||||
|
"ticker": self.ticker,
|
||||||
|
"action": self.action,
|
||||||
|
"shares": self.shares,
|
||||||
|
"price": self.price,
|
||||||
|
"total_value": self.total_value,
|
||||||
|
"trade_date": self.trade_date,
|
||||||
|
"rationale": self.rationale,
|
||||||
|
"signal_source": self.signal_source,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "Trade":
|
def from_dict(cls, data: dict[str, Any]) -> "Trade":
|
||||||
"""Deserialise from a DB row or JSON dict."""
|
"""Deserialise from a DB row or JSON dict."""
|
||||||
# TODO: implement
|
return cls(
|
||||||
raise NotImplementedError
|
trade_id=data["trade_id"],
|
||||||
|
portfolio_id=data["portfolio_id"],
|
||||||
|
ticker=data["ticker"],
|
||||||
|
action=data["action"],
|
||||||
|
shares=float(data["shares"]),
|
||||||
|
price=float(data["price"]),
|
||||||
|
total_value=float(data["total_value"]),
|
||||||
|
trade_date=data.get("trade_date", ""),
|
||||||
|
rationale=data.get("rationale"),
|
||||||
|
signal_source=data.get("signal_source"),
|
||||||
|
metadata=data.get("metadata") or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -194,8 +280,17 @@ class PortfolioSnapshot:
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Serialise all fields. ``holdings_snapshot`` is already a list[dict]."""
|
"""Serialise all fields. ``holdings_snapshot`` is already a list[dict]."""
|
||||||
# TODO: implement
|
return {
|
||||||
raise NotImplementedError
|
"snapshot_id": self.snapshot_id,
|
||||||
|
"portfolio_id": self.portfolio_id,
|
||||||
|
"snapshot_date": self.snapshot_date,
|
||||||
|
"total_value": self.total_value,
|
||||||
|
"cash": self.cash,
|
||||||
|
"equity_value": self.equity_value,
|
||||||
|
"num_positions": self.num_positions,
|
||||||
|
"holdings_snapshot": self.holdings_snapshot,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot":
|
def from_dict(cls, data: dict[str, Any]) -> "PortfolioSnapshot":
|
||||||
|
|
@ -203,5 +298,17 @@ class PortfolioSnapshot:
|
||||||
|
|
||||||
``holdings_snapshot`` is parsed from a JSON string when needed.
|
``holdings_snapshot`` is parsed from a JSON string when needed.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
holdings_snapshot = data.get("holdings_snapshot", [])
|
||||||
raise NotImplementedError
|
if isinstance(holdings_snapshot, str):
|
||||||
|
holdings_snapshot = json.loads(holdings_snapshot)
|
||||||
|
return cls(
|
||||||
|
snapshot_id=data["snapshot_id"],
|
||||||
|
portfolio_id=data["portfolio_id"],
|
||||||
|
snapshot_date=data["snapshot_date"],
|
||||||
|
total_value=float(data["total_value"]),
|
||||||
|
cash=float(data["cash"]),
|
||||||
|
equity_value=float(data["equity_value"]),
|
||||||
|
num_positions=int(data["num_positions"]),
|
||||||
|
holdings_snapshot=holdings_snapshot,
|
||||||
|
metadata=data.get("metadata") or {},
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,12 @@ Usage::
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from tradingagents.portfolio.exceptions import ReportStoreError
|
||||||
|
|
||||||
|
|
||||||
class ReportStore:
|
class ReportStore:
|
||||||
"""Filesystem document store for all portfolio-related reports.
|
"""Filesystem document store for all portfolio-related reports.
|
||||||
|
|
@ -48,8 +51,7 @@ class ReportStore:
|
||||||
Override via the ``PORTFOLIO_DATA_DIR`` env var or
|
Override via the ``PORTFOLIO_DATA_DIR`` env var or
|
||||||
``get_portfolio_config()["data_dir"]``.
|
``get_portfolio_config()["data_dir"]``.
|
||||||
"""
|
"""
|
||||||
# TODO: implement — store Path(base_dir), resolve as needed
|
self._base_dir = Path(base_dir)
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
|
|
@ -60,8 +62,7 @@ class ReportStore:
|
||||||
|
|
||||||
Path: ``{base_dir}/daily/{date}/portfolio/``
|
Path: ``{base_dir}/daily/{date}/portfolio/``
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
return self._base_dir / "daily" / date / "portfolio"
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
|
def _write_json(self, path: Path, data: dict[str, Any]) -> Path:
|
||||||
"""Write a dict to a JSON file, creating parent directories as needed.
|
"""Write a dict to a JSON file, creating parent directories as needed.
|
||||||
|
|
@ -76,8 +77,12 @@ class ReportStore:
|
||||||
Raises:
|
Raises:
|
||||||
ReportStoreError: On filesystem write failure.
|
ReportStoreError: On filesystem write failure.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
try:
|
||||||
raise NotImplementedError
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||||
|
return path
|
||||||
|
except OSError as exc:
|
||||||
|
raise ReportStoreError(f"Failed to write {path}: {exc}") from exc
|
||||||
|
|
||||||
def _read_json(self, path: Path) -> dict[str, Any] | None:
|
def _read_json(self, path: Path) -> dict[str, Any] | None:
|
||||||
"""Read a JSON file, returning None if the file does not exist.
|
"""Read a JSON file, returning None if the file does not exist.
|
||||||
|
|
@ -85,8 +90,12 @@ class ReportStore:
|
||||||
Raises:
|
Raises:
|
||||||
ReportStoreError: On JSON parse error (file exists but is corrupt).
|
ReportStoreError: On JSON parse error (file exists but is corrupt).
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
if not path.exists():
|
||||||
raise NotImplementedError
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ReportStoreError(f"Corrupt JSON at {path}: {exc}") from exc
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Macro Scan
|
# Macro Scan
|
||||||
|
|
@ -104,13 +113,13 @@ class ReportStore:
|
||||||
Returns:
|
Returns:
|
||||||
Path of the written file.
|
Path of the written file.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
|
||||||
raise NotImplementedError
|
return self._write_json(path, data)
|
||||||
|
|
||||||
def load_scan(self, date: str) -> dict[str, Any] | None:
|
def load_scan(self, date: str) -> dict[str, Any] | None:
|
||||||
"""Load macro scan summary. Returns None if the file does not exist."""
|
"""Load macro scan summary. Returns None if the file does not exist."""
|
||||||
# TODO: implement
|
path = self._base_dir / "daily" / date / "market" / "macro_scan_summary.json"
|
||||||
raise NotImplementedError
|
return self._read_json(path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Per-Ticker Analysis
|
# Per-Ticker Analysis
|
||||||
|
|
@ -126,13 +135,13 @@ class ReportStore:
|
||||||
ticker: Ticker symbol (stored as uppercase).
|
ticker: Ticker symbol (stored as uppercase).
|
||||||
data: Analysis output dict.
|
data: Analysis output dict.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
|
||||||
raise NotImplementedError
|
return self._write_json(path, data)
|
||||||
|
|
||||||
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
|
def load_analysis(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
"""Load per-ticker analysis JSON. Returns None if the file does not exist."""
|
"""Load per-ticker analysis JSON. Returns None if the file does not exist."""
|
||||||
# TODO: implement
|
path = self._base_dir / "daily" / date / ticker.upper() / "complete_report.json"
|
||||||
raise NotImplementedError
|
return self._read_json(path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Holding Reviews
|
# Holding Reviews
|
||||||
|
|
@ -153,13 +162,13 @@ class ReportStore:
|
||||||
ticker: Ticker symbol (stored as uppercase).
|
ticker: Ticker symbol (stored as uppercase).
|
||||||
data: HoldingReviewerAgent output dict.
|
data: HoldingReviewerAgent output dict.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
|
||||||
raise NotImplementedError
|
return self._write_json(path, data)
|
||||||
|
|
||||||
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
|
def load_holding_review(self, date: str, ticker: str) -> dict[str, Any] | None:
|
||||||
"""Load holding review output. Returns None if the file does not exist."""
|
"""Load holding review output. Returns None if the file does not exist."""
|
||||||
# TODO: implement
|
path = self._portfolio_dir(date) / f"{ticker.upper()}_holding_review.json"
|
||||||
raise NotImplementedError
|
return self._read_json(path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Risk Metrics
|
# Risk Metrics
|
||||||
|
|
@ -180,8 +189,8 @@ class ReportStore:
|
||||||
portfolio_id: UUID of the target portfolio.
|
portfolio_id: UUID of the target portfolio.
|
||||||
data: Risk metrics dict (Sharpe, Sortino, VaR, etc.).
|
data: Risk metrics dict (Sharpe, Sortino, VaR, etc.).
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
|
||||||
raise NotImplementedError
|
return self._write_json(path, data)
|
||||||
|
|
||||||
def load_risk_metrics(
|
def load_risk_metrics(
|
||||||
self,
|
self,
|
||||||
|
|
@ -189,8 +198,8 @@ class ReportStore:
|
||||||
portfolio_id: str,
|
portfolio_id: str,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Load risk metrics. Returns None if the file does not exist."""
|
"""Load risk metrics. Returns None if the file does not exist."""
|
||||||
# TODO: implement
|
path = self._portfolio_dir(date) / f"{portfolio_id}_risk_metrics.json"
|
||||||
raise NotImplementedError
|
return self._read_json(path)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# PM Decisions
|
# PM Decisions
|
||||||
|
|
@ -218,8 +227,15 @@ class ReportStore:
|
||||||
Returns:
|
Returns:
|
||||||
Path of the written JSON file.
|
Path of the written JSON file.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
json_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
|
||||||
raise NotImplementedError
|
self._write_json(json_path, data)
|
||||||
|
if markdown is not None:
|
||||||
|
md_path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.md"
|
||||||
|
try:
|
||||||
|
md_path.write_text(markdown, encoding="utf-8")
|
||||||
|
except OSError as exc:
|
||||||
|
raise ReportStoreError(f"Failed to write {md_path}: {exc}") from exc
|
||||||
|
return json_path
|
||||||
|
|
||||||
def load_pm_decision(
|
def load_pm_decision(
|
||||||
self,
|
self,
|
||||||
|
|
@ -227,8 +243,8 @@ class ReportStore:
|
||||||
portfolio_id: str,
|
portfolio_id: str,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Load PM decision JSON. Returns None if the file does not exist."""
|
"""Load PM decision JSON. Returns None if the file does not exist."""
|
||||||
# TODO: implement
|
path = self._portfolio_dir(date) / f"{portfolio_id}_pm_decision.json"
|
||||||
raise NotImplementedError
|
return self._read_json(path)
|
||||||
|
|
||||||
def list_pm_decisions(self, portfolio_id: str) -> list[Path]:
|
def list_pm_decisions(self, portfolio_id: str) -> list[Path]:
|
||||||
"""Return all saved PM decision JSON paths for portfolio_id, newest first.
|
"""Return all saved PM decision JSON paths for portfolio_id, newest first.
|
||||||
|
|
@ -241,5 +257,5 @@ class ReportStore:
|
||||||
Returns:
|
Returns:
|
||||||
Sorted list of Path objects, newest date first.
|
Sorted list of Path objects, newest date first.
|
||||||
"""
|
"""
|
||||||
# TODO: implement
|
pattern = f"daily/*/portfolio/{portfolio_id}_pm_decision.json"
|
||||||
raise NotImplementedError
|
return sorted(self._base_dir.glob(pattern), reverse=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue