perf(portfolio): batch database writes during bulk SELL executions

Replaces the O(N) database operations in the `TradeExecutor`'s
`execute_decisions` SELL loop with a single `batch_remove_holdings`
call to the repository. The new repository method calculates updates
in memory, resolves duplicate operations on the same ticker, and issues
the updates via newly implemented `psycopg2.extras.execute_batch`
routines on the `SupabaseClient`.

Co-authored-by: aguzererler <6199053+aguzererler@users.noreply.github.com>
This commit is contained in:
google-labs-jules[bot] 2026-03-21 20:06:45 +00:00
parent a7b8c996f2
commit 1ed46937d7
4 changed files with 212 additions and 26 deletions

View File

@ -89,8 +89,13 @@ PRICES = {"AAPL": 150.0, "MSFT": 300.0}
def test_execute_sell_success():
"""Successful SELL calls remove_holding and is in executed_trades."""
"""Successful SELL calls batch_remove_holdings and is in executed_trades."""
repo = _make_repo()
# Mock batch_remove_holdings to return a tuple of (executed, failed)
repo.batch_remove_holdings.return_value = (
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}],
[]
)
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
decisions = {
@ -99,7 +104,11 @@ def test_execute_sell_success():
}
result = executor.execute_decisions("p1", decisions, PRICES)
repo.remove_holding.assert_called_once_with("p1", "AAPL", 5.0, 150.0)
repo.batch_remove_holdings.assert_called_once()
args, kwargs = repo.batch_remove_holdings.call_args
assert args[0] == "p1"
assert args[1] == [{"ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}]
assert len(result["executed_trades"]) == 1
assert result["executed_trades"][0]["action"] == "SELL"
assert result["executed_trades"][0]["ticker"] == "AAPL"
@ -123,9 +132,12 @@ def test_execute_sell_missing_price():
def test_execute_sell_insufficient_shares():
"""SELL that raises InsufficientSharesError → failed_trade."""
"""SELL that fails due to logic in batch_remove_holdings → failed_trade."""
repo = _make_repo()
repo.remove_holding.side_effect = InsufficientSharesError("Not enough shares")
repo.batch_remove_holdings.return_value = (
[],
[{"action": "SELL", "ticker": "AAPL", "reason": "Hold 10.0 shares of AAPL, cannot sell 999.0"}]
)
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
decisions = {
@ -135,7 +147,7 @@ def test_execute_sell_insufficient_shares():
result = executor.execute_decisions("p1", decisions, PRICES)
assert len(result["failed_trades"]) == 1
assert "Not enough shares" in result["failed_trades"][0]["reason"]
assert "cannot sell 999.0" in result["failed_trades"][0]["reason"]
# ---------------------------------------------------------------------------
@ -207,6 +219,10 @@ def test_execute_decisions_sells_before_buys():
"""SELLs are always executed before BUYs."""
portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0)
repo = _make_repo(portfolio=portfolio)
repo.batch_remove_holdings.return_value = (
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Exit"}],
[]
)
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
decisions = {
@ -215,9 +231,9 @@ def test_execute_decisions_sells_before_buys():
}
executor.execute_decisions("p1", decisions, PRICES)
# Verify call order: remove_holding before add_holding
call_order = [c[0] for c in repo.method_calls if c[0] in ("remove_holding", "add_holding")]
assert call_order.index("remove_holding") < call_order.index("add_holding")
# Verify call order: batch_remove_holdings before add_holding
call_order = [c[0] for c in repo.method_calls if c[0] in ("batch_remove_holdings", "add_holding")]
assert call_order.index("batch_remove_holdings") < call_order.index("add_holding")
def test_execute_decisions_takes_snapshot():

View File

@ -233,6 +233,118 @@ class PortfolioRepository:
return result
def batch_remove_holdings(
self,
portfolio_id: str,
sells: list[dict[str, Any]],
trade_date: str,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Sell shares in batch and update portfolio cash and holdings.
Args:
portfolio_id: Portfolio ID.
sells: List of dicts with keys 'ticker', 'shares', 'price', 'rationale'.
trade_date: The date to record the trades.
Returns:
Tuple of (executed_trades, failed_trades).
"""
executed_trades = []
failed_trades = []
if not sells:
return executed_trades, failed_trades
# Pre-fetch portfolio and holdings once
portfolio = self._client.get_portfolio(portfolio_id)
current_holdings = {h.ticker.upper(): h for h in self._client.list_holdings(portfolio_id)}
holdings_to_upsert = {}
tickers_to_delete = set()
trades_to_record = []
total_proceeds = 0.0
for sell in sells:
ticker = sell["ticker"]
shares = sell["shares"]
price = sell["price"]
rationale = sell.get("rationale")
existing = current_holdings.get(ticker.upper())
if not existing:
failed_trades.append({
"action": "SELL",
"ticker": ticker,
"reason": f"No holding for {ticker} in portfolio {portfolio_id}",
})
continue
if existing.shares < shares:
failed_trades.append({
"action": "SELL",
"ticker": ticker,
"reason": f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}",
})
continue
proceeds = shares * price
total_proceeds += proceeds
if existing.shares == shares:
tickers_to_delete.add(ticker.upper())
# If we previously marked it to upsert, remove it
if ticker.upper() in holdings_to_upsert:
del holdings_to_upsert[ticker.upper()]
# Remove from local tracking
del current_holdings[ticker.upper()]
else:
existing.shares -= shares
holdings_to_upsert[ticker.upper()] = existing
trade = Trade(
trade_id=str(uuid.uuid4()),
portfolio_id=portfolio_id,
ticker=ticker.upper(),
action="SELL",
shares=shares,
price=price,
total_value=proceeds,
trade_date=trade_date,
rationale=rationale,
signal_source="pm_agent",
)
trades_to_record.append(trade)
executed_trades.append({
"action": "SELL",
"ticker": ticker,
"shares": shares,
"price": price,
"rationale": rationale,
"trade_date": trade_date,
})
if not executed_trades:
return executed_trades, failed_trades
try:
# Apply database writes in batch
if tickers_to_delete:
self._client.batch_delete_holdings(portfolio_id, list(tickers_to_delete))
if holdings_to_upsert:
self._client.batch_upsert_holdings(list(holdings_to_upsert.values()))
portfolio.cash += total_proceeds
self._client.update_portfolio(portfolio)
if trades_to_record:
self._client.batch_record_trades(trades_to_record)
except Exception as exc:
raise PortfolioError(f"Batch write failed: {exc}") from exc
return executed_trades, failed_trades
# ------------------------------------------------------------------
# Snapshots
# ------------------------------------------------------------------

View File

@ -222,6 +222,35 @@ class SupabaseClient:
f"Holding not found: {ticker} in portfolio {portfolio_id}"
)
def batch_upsert_holdings(self, holdings: list[Holding]) -> None:
if not holdings:
return
query = '''
INSERT INTO holdings
(holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
DO UPDATE SET shares = EXCLUDED.shares,
avg_cost = EXCLUDED.avg_cost,
sector = EXCLUDED.sector,
industry = EXCLUDED.industry
'''
params = [
(h.holding_id or str(uuid.uuid4()), h.portfolio_id, h.ticker.upper(), h.shares, h.avg_cost, h.sector, h.industry)
for h in holdings
]
with self._cursor() as cur:
psycopg2.extras.execute_batch(cur, query, params)
def batch_delete_holdings(self, portfolio_id: str, tickers: list[str]) -> None:
if not tickers:
return
query = "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s"
params = [(portfolio_id, ticker.upper()) for ticker in tickers]
with self._cursor() as cur:
psycopg2.extras.execute_batch(cur, query, params)
# ------------------------------------------------------------------
# Trades
# ------------------------------------------------------------------
@ -266,6 +295,25 @@ class SupabaseClient:
rows = cur.fetchall()
return [self._row_to_trade(r) for r in rows]
def batch_record_trades(self, trades: list[Trade]) -> None:
if not trades:
return
query = '''
INSERT INTO trades
(trade_id, portfolio_id, ticker, action, shares, price,
total_value, rationale, signal_source, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
'''
params = [
(t.trade_id or str(uuid.uuid4()), t.portfolio_id, t.ticker, t.action,
t.shares, t.price, t.total_value, t.rationale, t.signal_source,
json.dumps(t.metadata))
for t in trades
]
with self._cursor() as cur:
psycopg2.extras.execute_batch(cur, query, params)
# ------------------------------------------------------------------
# Snapshots
# ------------------------------------------------------------------

View File

@ -13,7 +13,6 @@ from typing import Any
from tradingagents.portfolio.exceptions import (
InsufficientCashError,
InsufficientSharesError,
PortfolioError,
)
from tradingagents.portfolio.risk_evaluator import check_constraints
@ -78,6 +77,7 @@ class TradeExecutor:
buys = decisions.get("buys") or []
# --- SELLs first (frees cash before BUYs; no constraint pre-flight for sells) ---
sells_to_process = []
for sell in sells:
ticker = (sell.get("ticker") or "").upper()
shares = float(sell.get("shares") or 0)
@ -102,24 +102,34 @@ class TradeExecutor:
logger.warning("execute_decisions: no price for %s — skipping SELL", ticker)
continue
sells_to_process.append({
"ticker": ticker,
"shares": shares,
"price": price,
"rationale": rationale,
})
if sells_to_process:
try:
self.repo.remove_holding(portfolio_id, ticker, shares, price)
executed_trades.append({
"action": "SELL",
"ticker": ticker,
"shares": shares,
"price": price,
"rationale": rationale,
"trade_date": trade_date,
})
logger.info("SELL %s x %.2f @ %.2f", ticker, shares, price)
except (InsufficientSharesError, PortfolioError) as exc:
failed_trades.append({
"action": "SELL",
"ticker": ticker,
"reason": str(exc),
})
logger.warning("SELL failed for %s: %s", ticker, exc)
executed, failed = self.repo.batch_remove_holdings(portfolio_id, sells_to_process, trade_date)
executed_trades.extend(executed)
for f in failed:
failed_trades.append({
"action": "SELL",
"ticker": f.get("ticker", "UNKNOWN"),
"reason": f.get("reason", "Batch execution failed"),
})
logger.warning("SELL failed for %s: %s", f.get("ticker"), f.get("reason"))
for e in executed:
logger.info("SELL %s x %.2f @ %.2f", e["ticker"], e["shares"], e["price"])
except PortfolioError as exc:
logger.error("Batch sell execution failed: %s", exc)
for s in sells_to_process:
failed_trades.append({
"action": "SELL",
"ticker": s["ticker"],
"reason": str(exc),
})
# --- BUYs second ---
for buy in buys: