perf(portfolio): batch database writes during bulk SELL executions
Replaces the O(N) database operations in the `TradeExecutor`'s `execute_decisions` SELL loop with a single `batch_remove_holdings` call to the repository. The new repository method calculates updates in memory, resolves duplicate operations on the same ticker, and issues the updates via newly implemented `psycopg2.extras.execute_batch` routines on the `SupabaseClient`. Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
parent
a7b8c996f2
commit
1ed46937d7
|
|
@ -89,8 +89,13 @@ PRICES = {"AAPL": 150.0, "MSFT": 300.0}
|
|||
|
||||
|
||||
def test_execute_sell_success():
|
||||
"""Successful SELL calls remove_holding and is in executed_trades."""
|
||||
"""Successful SELL calls batch_remove_holdings and is in executed_trades."""
|
||||
repo = _make_repo()
|
||||
# Mock batch_remove_holdings to return a tuple of (executed, failed)
|
||||
repo.batch_remove_holdings.return_value = (
|
||||
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}],
|
||||
[]
|
||||
)
|
||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||
|
||||
decisions = {
|
||||
|
|
@ -99,7 +104,11 @@ def test_execute_sell_success():
|
|||
}
|
||||
result = executor.execute_decisions("p1", decisions, PRICES)
|
||||
|
||||
repo.remove_holding.assert_called_once_with("p1", "AAPL", 5.0, 150.0)
|
||||
repo.batch_remove_holdings.assert_called_once()
|
||||
args, kwargs = repo.batch_remove_holdings.call_args
|
||||
assert args[0] == "p1"
|
||||
assert args[1] == [{"ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}]
|
||||
|
||||
assert len(result["executed_trades"]) == 1
|
||||
assert result["executed_trades"][0]["action"] == "SELL"
|
||||
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
||||
|
|
@ -123,9 +132,12 @@ def test_execute_sell_missing_price():
|
|||
|
||||
|
||||
def test_execute_sell_insufficient_shares():
|
||||
"""SELL that raises InsufficientSharesError → failed_trade."""
|
||||
"""SELL that fails due to logic in batch_remove_holdings → failed_trade."""
|
||||
repo = _make_repo()
|
||||
repo.remove_holding.side_effect = InsufficientSharesError("Not enough shares")
|
||||
repo.batch_remove_holdings.return_value = (
|
||||
[],
|
||||
[{"action": "SELL", "ticker": "AAPL", "reason": "Hold 10.0 shares of AAPL, cannot sell 999.0"}]
|
||||
)
|
||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||
|
||||
decisions = {
|
||||
|
|
@ -135,7 +147,7 @@ def test_execute_sell_insufficient_shares():
|
|||
result = executor.execute_decisions("p1", decisions, PRICES)
|
||||
|
||||
assert len(result["failed_trades"]) == 1
|
||||
assert "Not enough shares" in result["failed_trades"][0]["reason"]
|
||||
assert "cannot sell 999.0" in result["failed_trades"][0]["reason"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -207,6 +219,10 @@ def test_execute_decisions_sells_before_buys():
|
|||
"""SELLs are always executed before BUYs."""
|
||||
portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0)
|
||||
repo = _make_repo(portfolio=portfolio)
|
||||
repo.batch_remove_holdings.return_value = (
|
||||
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Exit"}],
|
||||
[]
|
||||
)
|
||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||
|
||||
decisions = {
|
||||
|
|
@ -215,9 +231,9 @@ def test_execute_decisions_sells_before_buys():
|
|||
}
|
||||
executor.execute_decisions("p1", decisions, PRICES)
|
||||
|
||||
# Verify call order: remove_holding before add_holding
|
||||
call_order = [c[0] for c in repo.method_calls if c[0] in ("remove_holding", "add_holding")]
|
||||
assert call_order.index("remove_holding") < call_order.index("add_holding")
|
||||
# Verify call order: batch_remove_holdings before add_holding
|
||||
call_order = [c[0] for c in repo.method_calls if c[0] in ("batch_remove_holdings", "add_holding")]
|
||||
assert call_order.index("batch_remove_holdings") < call_order.index("add_holding")
|
||||
|
||||
|
||||
def test_execute_decisions_takes_snapshot():
|
||||
|
|
|
|||
|
|
@ -233,6 +233,118 @@ class PortfolioRepository:
|
|||
|
||||
return result
|
||||
|
||||
|
||||
def batch_remove_holdings(
|
||||
self,
|
||||
portfolio_id: str,
|
||||
sells: list[dict[str, Any]],
|
||||
trade_date: str,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Sell shares in batch and update portfolio cash and holdings.
|
||||
|
||||
Args:
|
||||
portfolio_id: Portfolio ID.
|
||||
sells: List of dicts with keys 'ticker', 'shares', 'price', 'rationale'.
|
||||
trade_date: The date to record the trades.
|
||||
|
||||
Returns:
|
||||
Tuple of (executed_trades, failed_trades).
|
||||
"""
|
||||
executed_trades = []
|
||||
failed_trades = []
|
||||
|
||||
if not sells:
|
||||
return executed_trades, failed_trades
|
||||
|
||||
# Pre-fetch portfolio and holdings once
|
||||
portfolio = self._client.get_portfolio(portfolio_id)
|
||||
current_holdings = {h.ticker.upper(): h for h in self._client.list_holdings(portfolio_id)}
|
||||
|
||||
holdings_to_upsert = {}
|
||||
tickers_to_delete = set()
|
||||
trades_to_record = []
|
||||
total_proceeds = 0.0
|
||||
|
||||
for sell in sells:
|
||||
ticker = sell["ticker"]
|
||||
shares = sell["shares"]
|
||||
price = sell["price"]
|
||||
rationale = sell.get("rationale")
|
||||
|
||||
existing = current_holdings.get(ticker.upper())
|
||||
if not existing:
|
||||
failed_trades.append({
|
||||
"action": "SELL",
|
||||
"ticker": ticker,
|
||||
"reason": f"No holding for {ticker} in portfolio {portfolio_id}",
|
||||
})
|
||||
continue
|
||||
|
||||
if existing.shares < shares:
|
||||
failed_trades.append({
|
||||
"action": "SELL",
|
||||
"ticker": ticker,
|
||||
"reason": f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}",
|
||||
})
|
||||
continue
|
||||
|
||||
proceeds = shares * price
|
||||
total_proceeds += proceeds
|
||||
|
||||
if existing.shares == shares:
|
||||
tickers_to_delete.add(ticker.upper())
|
||||
# If we previously marked it to upsert, remove it
|
||||
if ticker.upper() in holdings_to_upsert:
|
||||
del holdings_to_upsert[ticker.upper()]
|
||||
# Remove from local tracking
|
||||
del current_holdings[ticker.upper()]
|
||||
else:
|
||||
existing.shares -= shares
|
||||
holdings_to_upsert[ticker.upper()] = existing
|
||||
|
||||
trade = Trade(
|
||||
trade_id=str(uuid.uuid4()),
|
||||
portfolio_id=portfolio_id,
|
||||
ticker=ticker.upper(),
|
||||
action="SELL",
|
||||
shares=shares,
|
||||
price=price,
|
||||
total_value=proceeds,
|
||||
trade_date=trade_date,
|
||||
rationale=rationale,
|
||||
signal_source="pm_agent",
|
||||
)
|
||||
trades_to_record.append(trade)
|
||||
|
||||
executed_trades.append({
|
||||
"action": "SELL",
|
||||
"ticker": ticker,
|
||||
"shares": shares,
|
||||
"price": price,
|
||||
"rationale": rationale,
|
||||
"trade_date": trade_date,
|
||||
})
|
||||
|
||||
if not executed_trades:
|
||||
return executed_trades, failed_trades
|
||||
|
||||
try:
|
||||
# Apply database writes in batch
|
||||
if tickers_to_delete:
|
||||
self._client.batch_delete_holdings(portfolio_id, list(tickers_to_delete))
|
||||
if holdings_to_upsert:
|
||||
self._client.batch_upsert_holdings(list(holdings_to_upsert.values()))
|
||||
|
||||
portfolio.cash += total_proceeds
|
||||
self._client.update_portfolio(portfolio)
|
||||
|
||||
if trades_to_record:
|
||||
self._client.batch_record_trades(trades_to_record)
|
||||
except Exception as exc:
|
||||
raise PortfolioError(f"Batch write failed: {exc}") from exc
|
||||
|
||||
return executed_trades, failed_trades
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshots
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -222,6 +222,35 @@ class SupabaseClient:
|
|||
f"Holding not found: {ticker} in portfolio {portfolio_id}"
|
||||
)
|
||||
|
||||
|
||||
def batch_upsert_holdings(self, holdings: list[Holding]) -> None:
|
||||
if not holdings:
|
||||
return
|
||||
query = '''
|
||||
INSERT INTO holdings
|
||||
(holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||
ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
|
||||
DO UPDATE SET shares = EXCLUDED.shares,
|
||||
avg_cost = EXCLUDED.avg_cost,
|
||||
sector = EXCLUDED.sector,
|
||||
industry = EXCLUDED.industry
|
||||
'''
|
||||
params = [
|
||||
(h.holding_id or str(uuid.uuid4()), h.portfolio_id, h.ticker.upper(), h.shares, h.avg_cost, h.sector, h.industry)
|
||||
for h in holdings
|
||||
]
|
||||
with self._cursor() as cur:
|
||||
psycopg2.extras.execute_batch(cur, query, params)
|
||||
|
||||
def batch_delete_holdings(self, portfolio_id: str, tickers: list[str]) -> None:
|
||||
if not tickers:
|
||||
return
|
||||
query = "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s"
|
||||
params = [(portfolio_id, ticker.upper()) for ticker in tickers]
|
||||
with self._cursor() as cur:
|
||||
psycopg2.extras.execute_batch(cur, query, params)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trades
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -266,6 +295,25 @@ class SupabaseClient:
|
|||
rows = cur.fetchall()
|
||||
return [self._row_to_trade(r) for r in rows]
|
||||
|
||||
|
||||
def batch_record_trades(self, trades: list[Trade]) -> None:
|
||||
if not trades:
|
||||
return
|
||||
query = '''
|
||||
INSERT INTO trades
|
||||
(trade_id, portfolio_id, ticker, action, shares, price,
|
||||
total_value, rationale, signal_source, metadata)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
'''
|
||||
params = [
|
||||
(t.trade_id or str(uuid.uuid4()), t.portfolio_id, t.ticker, t.action,
|
||||
t.shares, t.price, t.total_value, t.rationale, t.signal_source,
|
||||
json.dumps(t.metadata))
|
||||
for t in trades
|
||||
]
|
||||
with self._cursor() as cur:
|
||||
psycopg2.extras.execute_batch(cur, query, params)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Snapshots
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from typing import Any
|
|||
|
||||
from tradingagents.portfolio.exceptions import (
|
||||
InsufficientCashError,
|
||||
InsufficientSharesError,
|
||||
PortfolioError,
|
||||
)
|
||||
from tradingagents.portfolio.risk_evaluator import check_constraints
|
||||
|
|
@ -78,6 +77,7 @@ class TradeExecutor:
|
|||
buys = decisions.get("buys") or []
|
||||
|
||||
# --- SELLs first (frees cash before BUYs; no constraint pre-flight for sells) ---
|
||||
sells_to_process = []
|
||||
for sell in sells:
|
||||
ticker = (sell.get("ticker") or "").upper()
|
||||
shares = float(sell.get("shares") or 0)
|
||||
|
|
@ -102,24 +102,34 @@ class TradeExecutor:
|
|||
logger.warning("execute_decisions: no price for %s — skipping SELL", ticker)
|
||||
continue
|
||||
|
||||
try:
|
||||
self.repo.remove_holding(portfolio_id, ticker, shares, price)
|
||||
executed_trades.append({
|
||||
"action": "SELL",
|
||||
sells_to_process.append({
|
||||
"ticker": ticker,
|
||||
"shares": shares,
|
||||
"price": price,
|
||||
"rationale": rationale,
|
||||
"trade_date": trade_date,
|
||||
})
|
||||
logger.info("SELL %s x %.2f @ %.2f", ticker, shares, price)
|
||||
except (InsufficientSharesError, PortfolioError) as exc:
|
||||
|
||||
if sells_to_process:
|
||||
try:
|
||||
executed, failed = self.repo.batch_remove_holdings(portfolio_id, sells_to_process, trade_date)
|
||||
executed_trades.extend(executed)
|
||||
for f in failed:
|
||||
failed_trades.append({
|
||||
"action": "SELL",
|
||||
"ticker": ticker,
|
||||
"ticker": f.get("ticker", "UNKNOWN"),
|
||||
"reason": f.get("reason", "Batch execution failed"),
|
||||
})
|
||||
logger.warning("SELL failed for %s: %s", f.get("ticker"), f.get("reason"))
|
||||
for e in executed:
|
||||
logger.info("SELL %s x %.2f @ %.2f", e["ticker"], e["shares"], e["price"])
|
||||
except PortfolioError as exc:
|
||||
logger.error("Batch sell execution failed: %s", exc)
|
||||
for s in sells_to_process:
|
||||
failed_trades.append({
|
||||
"action": "SELL",
|
||||
"ticker": s["ticker"],
|
||||
"reason": str(exc),
|
||||
})
|
||||
logger.warning("SELL failed for %s: %s", ticker, exc)
|
||||
|
||||
# --- BUYs second ---
|
||||
for buy in buys:
|
||||
|
|
|
|||
Loading…
Reference in New Issue