250 lines
10 KiB
Python
250 lines
10 KiB
Python
"""Tests for tradingagents/portfolio/models.py.
|
|
|
|
Tests the four dataclass models: Portfolio, Holding, Trade, PortfolioSnapshot.
|
|
|
|
Coverage targets:
|
|
- to_dict() / from_dict() round-trips
|
|
- enrich() computed-field logic
|
|
- Edge cases (zero cost basis, zero portfolio value)
|
|
|
|
Run::
|
|
|
|
pytest tests/portfolio/test_models.py -v
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from tradingagents.portfolio.models import (
|
|
Holding,
|
|
Portfolio,
|
|
PortfolioSnapshot,
|
|
Trade,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Portfolio round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_portfolio_to_dict_round_trip(sample_portfolio):
    """A Portfolio rebuilt from its own to_dict() output must match the original.

    Each attribute listed below is compared between the rebuilt object and
    the fixture; the first mismatch fails the test.
    """
    rebuilt = Portfolio.from_dict(sample_portfolio.to_dict())
    compared_attrs = (
        "portfolio_id",
        "name",
        "cash",
        "initial_cash",
        "currency",
        "created_at",
        "updated_at",
        "report_path",
        "metadata",
    )
    for attr in compared_attrs:
        assert getattr(rebuilt, attr) == getattr(sample_portfolio, attr)
|
|
|
|
|
|
def test_portfolio_to_dict_excludes_runtime_fields(sample_portfolio):
    """The dict produced by to_dict() must omit total_value, equity_value and cash_pct."""
    serialised = sample_portfolio.to_dict()
    for computed_key in ("total_value", "equity_value", "cash_pct"):
        assert computed_key not in serialised
|
|
|
|
|
|
def test_portfolio_from_dict_defaults_optional_fields():
    """from_dict() must supply defaults when the optional keys are absent from the input."""
    # Only id, name and the two cash figures are provided; everything
    # asserted below has to come from from_dict()'s own defaults.
    portfolio = Portfolio.from_dict(
        {"portfolio_id": "pid-1", "name": "Minimal", "cash": 1000.0, "initial_cash": 1000.0}
    )
    assert portfolio.currency == "USD"
    assert portfolio.created_at == ""
    assert portfolio.updated_at == ""
    assert portfolio.report_path is None
    assert portfolio.metadata == {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Holding round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_holding_to_dict_round_trip(sample_holding):
    """A Holding rebuilt from its own to_dict() output must match the original.

    The attributes listed below are compared one by one between the rebuilt
    object and the fixture.
    """
    clone = Holding.from_dict(sample_holding.to_dict())
    compared_attrs = (
        "holding_id",
        "portfolio_id",
        "ticker",
        "shares",
        "avg_cost",
        "sector",
        "industry",
    )
    for attr in compared_attrs:
        assert getattr(clone, attr) == getattr(sample_holding, attr)
|
|
|
|
|
|
def test_holding_to_dict_excludes_runtime_fields(sample_holding):
    """The dict produced by Holding.to_dict() must omit every key listed below."""
    payload = sample_holding.to_dict()
    runtime_only_keys = (
        "current_price",
        "current_value",
        "cost_basis",
        "unrealized_pnl",
        "unrealized_pnl_pct",
        "weight",
    )
    for key in runtime_only_keys:
        assert key not in payload
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Trade round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_trade_to_dict_round_trip(sample_trade):
    """A Trade rebuilt from its own to_dict() output must match the original.

    The attributes listed below are compared one by one between the rebuilt
    object and the fixture.
    """
    rebuilt = Trade.from_dict(sample_trade.to_dict())
    compared_attrs = (
        "trade_id",
        "portfolio_id",
        "ticker",
        "action",
        "shares",
        "price",
        "total_value",
        "trade_date",
        "rationale",
        "signal_source",
        "metadata",
    )
    for attr in compared_attrs:
        assert getattr(rebuilt, attr) == getattr(sample_trade, attr)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PortfolioSnapshot round-trip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_snapshot_to_dict_round_trip(sample_snapshot):
    """A PortfolioSnapshot rebuilt from its own to_dict() output must match the original.

    The attributes listed below are compared one by one between the rebuilt
    object and the fixture.
    """
    rebuilt = PortfolioSnapshot.from_dict(sample_snapshot.to_dict())
    compared_attrs = (
        "snapshot_id",
        "portfolio_id",
        "snapshot_date",
        "total_value",
        "cash",
        "equity_value",
        "num_positions",
        "holdings_snapshot",
        "metadata",
    )
    for attr in compared_attrs:
        assert getattr(rebuilt, attr) == getattr(sample_snapshot, attr)
|
|
|
|
|
|
def test_snapshot_from_dict_parses_holdings_snapshot_json_string():
    """from_dict() must decode holdings_snapshot when it is supplied as a JSON string."""
    import json

    expected_holdings = [{"ticker": "AAPL", "shares": 10.0}]
    # Pass the list JSON-encoded (a str): the string form as returned by Supabase.
    encoded_holdings = json.dumps(expected_holdings)

    snapshot = PortfolioSnapshot.from_dict(
        {
            "snapshot_id": "snap-1",
            "portfolio_id": "pid-1",
            "snapshot_date": "2026-03-20",
            "total_value": 110_000.0,
            "cash": 10_000.0,
            "equity_value": 100_000.0,
            "num_positions": 1,
            "holdings_snapshot": encoded_holdings,
        }
    )

    assert snapshot.holdings_snapshot == expected_holdings
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Holding.enrich()
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_holding_enrich_computes_current_value(sample_holding):
    """After enrich(), current_value must equal the supplied price times the share count."""
    price = 200.0
    sample_holding.enrich(current_price=price, portfolio_total_value=100_000.0)
    assert sample_holding.current_value == price * sample_holding.shares
|
|
|
|
|
|
def test_holding_enrich_computes_unrealized_pnl(sample_holding):
    """After enrich(), unrealized_pnl must equal current_value minus avg_cost * shares."""
    sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)

    # The cost basis is recomputed here from the holding's own inputs.
    basis = sample_holding.avg_cost * sample_holding.shares
    pnl = sample_holding.current_value - basis

    assert sample_holding.unrealized_pnl == pytest.approx(pnl)
|
|
|
|
|
|
def test_holding_enrich_computes_unrealized_pnl_pct(sample_holding):
    """After enrich(), unrealized_pnl_pct must equal unrealized_pnl over avg_cost * shares."""
    sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)

    # The ratio is expressed as a fraction of the recomputed cost basis.
    basis = sample_holding.avg_cost * sample_holding.shares
    pnl_fraction = sample_holding.unrealized_pnl / basis

    assert sample_holding.unrealized_pnl_pct == pytest.approx(pnl_fraction)
|
|
|
|
|
|
def test_holding_enrich_computes_weight(sample_holding):
    """After enrich(), weight must equal current_value over the supplied portfolio total."""
    portfolio_total = 100_000.0
    sample_holding.enrich(current_price=200.0, portfolio_total_value=portfolio_total)

    weight = sample_holding.current_value / portfolio_total

    assert sample_holding.weight == pytest.approx(weight)
|
|
|
|
|
|
def test_holding_enrich_returns_self(sample_holding):
    """Holding.enrich() must hand back the very object it was called on (for chaining)."""
    assert sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0) is sample_holding
|
|
|
|
|
|
def test_holding_enrich_handles_zero_cost(sample_holding):
    """With avg_cost set to 0, enrich() must report unrealized_pnl_pct as 0.

    A ZeroDivisionError raised inside enrich() would fail this test.
    """
    # Force a zero cost basis on the fixture before enriching.
    sample_holding.avg_cost = 0.0

    sample_holding.enrich(current_price=200.0, portfolio_total_value=100_000.0)

    assert sample_holding.unrealized_pnl_pct == 0.0
|
|
|
|
|
|
def test_holding_enrich_handles_zero_portfolio_value(sample_holding):
    """With a portfolio total of 0, enrich() must report weight as 0.

    A ZeroDivisionError raised inside enrich() would fail this test.
    """
    empty_total = 0.0
    sample_holding.enrich(current_price=200.0, portfolio_total_value=empty_total)
    assert sample_holding.weight == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Portfolio.enrich()
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_portfolio_enrich_computes_total_value(sample_portfolio, sample_holding):
    """After Portfolio.enrich(), total_value must equal cash plus the holding's value."""
    price = 200.0
    # Sets current_value on the holding; the dummy total of 1.0 is overwritten
    # by portfolio.enrich().
    sample_holding.enrich(current_price=price, portfolio_total_value=1.0)

    sample_portfolio.enrich([sample_holding])

    equity = price * sample_holding.shares
    assert sample_portfolio.total_value == pytest.approx(sample_portfolio.cash + equity)
|
|
|
|
|
|
def test_portfolio_enrich_computes_equity_value(sample_portfolio, sample_holding):
    """After Portfolio.enrich(), equity_value must equal the single holding's value."""
    price = 200.0
    # Sets current_value on the holding; the dummy total of 1.0 is overwritten
    # by portfolio.enrich().
    sample_holding.enrich(current_price=price, portfolio_total_value=1.0)

    sample_portfolio.enrich([sample_holding])

    assert sample_portfolio.equity_value == pytest.approx(price * sample_holding.shares)
|
|
|
|
|
|
def test_portfolio_enrich_computes_cash_pct(sample_portfolio, sample_holding):
    """After Portfolio.enrich(), cash_pct must equal cash divided by total_value."""
    # Sets current_value on the holding; the dummy total of 1.0 is overwritten
    # by portfolio.enrich().
    sample_holding.enrich(current_price=200.0, portfolio_total_value=1.0)

    sample_portfolio.enrich([sample_holding])

    cash_fraction = sample_portfolio.cash / sample_portfolio.total_value
    assert sample_portfolio.cash_pct == pytest.approx(cash_fraction)
|
|
|
|
|
|
def test_portfolio_enrich_returns_self(sample_portfolio):
    """Portfolio.enrich() must hand back the very object it was called on (for chaining)."""
    assert sample_portfolio.enrich([]) is sample_portfolio
|
|
|
|
|
|
def test_portfolio_enrich_no_holdings(sample_portfolio):
    """Enriching with no holdings must give zero equity, total_value == cash, cash_pct == 1."""
    no_holdings = []
    sample_portfolio.enrich(no_holdings)

    assert sample_portfolio.equity_value == 0.0
    assert sample_portfolio.total_value == sample_portfolio.cash
    assert sample_portfolio.cash_pct == 1.0
|