Merge pull request #75 from aguzererler/optimize-sell-batching-17320376887335337420
⚡ Batch database writes for portfolio SELL operations
This commit is contained in:
commit
e0b882ed75
|
|
@ -89,8 +89,13 @@ PRICES = {"AAPL": 150.0, "MSFT": 300.0}
|
||||||
|
|
||||||
|
|
||||||
def test_execute_sell_success():
|
def test_execute_sell_success():
|
||||||
"""Successful SELL calls remove_holding and is in executed_trades."""
|
"""Successful SELL calls batch_remove_holdings and is in executed_trades."""
|
||||||
repo = _make_repo()
|
repo = _make_repo()
|
||||||
|
# Mock batch_remove_holdings to return a tuple of (executed, failed)
|
||||||
|
repo.batch_remove_holdings.return_value = (
|
||||||
|
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}],
|
||||||
|
[]
|
||||||
|
)
|
||||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||||
|
|
||||||
decisions = {
|
decisions = {
|
||||||
|
|
@ -99,7 +104,11 @@ def test_execute_sell_success():
|
||||||
}
|
}
|
||||||
result = executor.execute_decisions("p1", decisions, PRICES)
|
result = executor.execute_decisions("p1", decisions, PRICES)
|
||||||
|
|
||||||
repo.remove_holding.assert_called_once_with("p1", "AAPL", 5.0, 150.0)
|
repo.batch_remove_holdings.assert_called_once()
|
||||||
|
args, kwargs = repo.batch_remove_holdings.call_args
|
||||||
|
assert args[0] == "p1"
|
||||||
|
assert args[1] == [{"ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}]
|
||||||
|
|
||||||
assert len(result["executed_trades"]) == 1
|
assert len(result["executed_trades"]) == 1
|
||||||
assert result["executed_trades"][0]["action"] == "SELL"
|
assert result["executed_trades"][0]["action"] == "SELL"
|
||||||
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
||||||
|
|
@ -123,9 +132,12 @@ def test_execute_sell_missing_price():
|
||||||
|
|
||||||
|
|
||||||
def test_execute_sell_insufficient_shares():
|
def test_execute_sell_insufficient_shares():
|
||||||
"""SELL that raises InsufficientSharesError → failed_trade."""
|
"""SELL that fails due to logic in batch_remove_holdings → failed_trade."""
|
||||||
repo = _make_repo()
|
repo = _make_repo()
|
||||||
repo.remove_holding.side_effect = InsufficientSharesError("Not enough shares")
|
repo.batch_remove_holdings.return_value = (
|
||||||
|
[],
|
||||||
|
[{"action": "SELL", "ticker": "AAPL", "reason": "Hold 10.0 shares of AAPL, cannot sell 999.0"}]
|
||||||
|
)
|
||||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||||
|
|
||||||
decisions = {
|
decisions = {
|
||||||
|
|
@ -135,7 +147,7 @@ def test_execute_sell_insufficient_shares():
|
||||||
result = executor.execute_decisions("p1", decisions, PRICES)
|
result = executor.execute_decisions("p1", decisions, PRICES)
|
||||||
|
|
||||||
assert len(result["failed_trades"]) == 1
|
assert len(result["failed_trades"]) == 1
|
||||||
assert "Not enough shares" in result["failed_trades"][0]["reason"]
|
assert "cannot sell 999.0" in result["failed_trades"][0]["reason"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -207,6 +219,10 @@ def test_execute_decisions_sells_before_buys():
|
||||||
"""SELLs are always executed before BUYs."""
|
"""SELLs are always executed before BUYs."""
|
||||||
portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0)
|
portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0)
|
||||||
repo = _make_repo(portfolio=portfolio)
|
repo = _make_repo(portfolio=portfolio)
|
||||||
|
repo.batch_remove_holdings.return_value = (
|
||||||
|
[{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Exit"}],
|
||||||
|
[]
|
||||||
|
)
|
||||||
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG)
|
||||||
|
|
||||||
decisions = {
|
decisions = {
|
||||||
|
|
@ -215,9 +231,9 @@ def test_execute_decisions_sells_before_buys():
|
||||||
}
|
}
|
||||||
executor.execute_decisions("p1", decisions, PRICES)
|
executor.execute_decisions("p1", decisions, PRICES)
|
||||||
|
|
||||||
# Verify call order: remove_holding before add_holding
|
# Verify call order: batch_remove_holdings before add_holding
|
||||||
call_order = [c[0] for c in repo.method_calls if c[0] in ("remove_holding", "add_holding")]
|
call_order = [c[0] for c in repo.method_calls if c[0] in ("batch_remove_holdings", "add_holding")]
|
||||||
assert call_order.index("remove_holding") < call_order.index("add_holding")
|
assert call_order.index("batch_remove_holdings") < call_order.index("add_holding")
|
||||||
|
|
||||||
|
|
||||||
def test_execute_decisions_takes_snapshot():
|
def test_execute_decisions_takes_snapshot():
|
||||||
|
|
|
||||||
|
|
@ -233,6 +233,118 @@ class PortfolioRepository:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def batch_remove_holdings(
|
||||||
|
self,
|
||||||
|
portfolio_id: str,
|
||||||
|
sells: list[dict[str, Any]],
|
||||||
|
trade_date: str,
|
||||||
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||||
|
"""Sell shares in batch and update portfolio cash and holdings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
portfolio_id: Portfolio ID.
|
||||||
|
sells: List of dicts with keys 'ticker', 'shares', 'price', 'rationale'.
|
||||||
|
trade_date: The date to record the trades.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (executed_trades, failed_trades).
|
||||||
|
"""
|
||||||
|
executed_trades = []
|
||||||
|
failed_trades = []
|
||||||
|
|
||||||
|
if not sells:
|
||||||
|
return executed_trades, failed_trades
|
||||||
|
|
||||||
|
# Pre-fetch portfolio and holdings once
|
||||||
|
portfolio = self._client.get_portfolio(portfolio_id)
|
||||||
|
current_holdings = {h.ticker.upper(): h for h in self._client.list_holdings(portfolio_id)}
|
||||||
|
|
||||||
|
holdings_to_upsert = {}
|
||||||
|
tickers_to_delete = set()
|
||||||
|
trades_to_record = []
|
||||||
|
total_proceeds = 0.0
|
||||||
|
|
||||||
|
for sell in sells:
|
||||||
|
ticker = sell["ticker"]
|
||||||
|
shares = sell["shares"]
|
||||||
|
price = sell["price"]
|
||||||
|
rationale = sell.get("rationale")
|
||||||
|
|
||||||
|
existing = current_holdings.get(ticker.upper())
|
||||||
|
if not existing:
|
||||||
|
failed_trades.append({
|
||||||
|
"action": "SELL",
|
||||||
|
"ticker": ticker,
|
||||||
|
"reason": f"No holding for {ticker} in portfolio {portfolio_id}",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if existing.shares < shares:
|
||||||
|
failed_trades.append({
|
||||||
|
"action": "SELL",
|
||||||
|
"ticker": ticker,
|
||||||
|
"reason": f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
proceeds = shares * price
|
||||||
|
total_proceeds += proceeds
|
||||||
|
|
||||||
|
if existing.shares == shares:
|
||||||
|
tickers_to_delete.add(ticker.upper())
|
||||||
|
# If we previously marked it to upsert, remove it
|
||||||
|
if ticker.upper() in holdings_to_upsert:
|
||||||
|
del holdings_to_upsert[ticker.upper()]
|
||||||
|
# Remove from local tracking
|
||||||
|
del current_holdings[ticker.upper()]
|
||||||
|
else:
|
||||||
|
existing.shares -= shares
|
||||||
|
holdings_to_upsert[ticker.upper()] = existing
|
||||||
|
|
||||||
|
trade = Trade(
|
||||||
|
trade_id=str(uuid.uuid4()),
|
||||||
|
portfolio_id=portfolio_id,
|
||||||
|
ticker=ticker.upper(),
|
||||||
|
action="SELL",
|
||||||
|
shares=shares,
|
||||||
|
price=price,
|
||||||
|
total_value=proceeds,
|
||||||
|
trade_date=trade_date,
|
||||||
|
rationale=rationale,
|
||||||
|
signal_source="pm_agent",
|
||||||
|
)
|
||||||
|
trades_to_record.append(trade)
|
||||||
|
|
||||||
|
executed_trades.append({
|
||||||
|
"action": "SELL",
|
||||||
|
"ticker": ticker,
|
||||||
|
"shares": shares,
|
||||||
|
"price": price,
|
||||||
|
"rationale": rationale,
|
||||||
|
"trade_date": trade_date,
|
||||||
|
})
|
||||||
|
|
||||||
|
if not executed_trades:
|
||||||
|
return executed_trades, failed_trades
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply database writes in batch
|
||||||
|
if tickers_to_delete:
|
||||||
|
self._client.batch_delete_holdings(portfolio_id, list(tickers_to_delete))
|
||||||
|
if holdings_to_upsert:
|
||||||
|
self._client.batch_upsert_holdings(list(holdings_to_upsert.values()))
|
||||||
|
|
||||||
|
portfolio.cash += total_proceeds
|
||||||
|
self._client.update_portfolio(portfolio)
|
||||||
|
|
||||||
|
if trades_to_record:
|
||||||
|
self._client.batch_record_trades(trades_to_record)
|
||||||
|
except Exception as exc:
|
||||||
|
raise PortfolioError(f"Batch write failed: {exc}") from exc
|
||||||
|
|
||||||
|
return executed_trades, failed_trades
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Snapshots
|
# Snapshots
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -222,6 +222,35 @@ class SupabaseClient:
|
||||||
f"Holding not found: {ticker} in portfolio {portfolio_id}"
|
f"Holding not found: {ticker} in portfolio {portfolio_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_upsert_holdings(self, holdings: list[Holding]) -> None:
|
||||||
|
if not holdings:
|
||||||
|
return
|
||||||
|
query = '''
|
||||||
|
INSERT INTO holdings
|
||||||
|
(holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique
|
||||||
|
DO UPDATE SET shares = EXCLUDED.shares,
|
||||||
|
avg_cost = EXCLUDED.avg_cost,
|
||||||
|
sector = EXCLUDED.sector,
|
||||||
|
industry = EXCLUDED.industry
|
||||||
|
'''
|
||||||
|
params = [
|
||||||
|
(h.holding_id or str(uuid.uuid4()), h.portfolio_id, h.ticker.upper(), h.shares, h.avg_cost, h.sector, h.industry)
|
||||||
|
for h in holdings
|
||||||
|
]
|
||||||
|
with self._cursor() as cur:
|
||||||
|
psycopg2.extras.execute_batch(cur, query, params)
|
||||||
|
|
||||||
|
def batch_delete_holdings(self, portfolio_id: str, tickers: list[str]) -> None:
|
||||||
|
if not tickers:
|
||||||
|
return
|
||||||
|
query = "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s"
|
||||||
|
params = [(portfolio_id, ticker.upper()) for ticker in tickers]
|
||||||
|
with self._cursor() as cur:
|
||||||
|
psycopg2.extras.execute_batch(cur, query, params)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Trades
|
# Trades
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -266,6 +295,25 @@ class SupabaseClient:
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
return [self._row_to_trade(r) for r in rows]
|
return [self._row_to_trade(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
def batch_record_trades(self, trades: list[Trade]) -> None:
|
||||||
|
if not trades:
|
||||||
|
return
|
||||||
|
query = '''
|
||||||
|
INSERT INTO trades
|
||||||
|
(trade_id, portfolio_id, ticker, action, shares, price,
|
||||||
|
total_value, rationale, signal_source, metadata)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
'''
|
||||||
|
params = [
|
||||||
|
(t.trade_id or str(uuid.uuid4()), t.portfolio_id, t.ticker, t.action,
|
||||||
|
t.shares, t.price, t.total_value, t.rationale, t.signal_source,
|
||||||
|
json.dumps(t.metadata))
|
||||||
|
for t in trades
|
||||||
|
]
|
||||||
|
with self._cursor() as cur:
|
||||||
|
psycopg2.extras.execute_batch(cur, query, params)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Snapshots
|
# Snapshots
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ from typing import Any
|
||||||
|
|
||||||
from tradingagents.portfolio.exceptions import (
|
from tradingagents.portfolio.exceptions import (
|
||||||
InsufficientCashError,
|
InsufficientCashError,
|
||||||
InsufficientSharesError,
|
|
||||||
PortfolioError,
|
PortfolioError,
|
||||||
)
|
)
|
||||||
from tradingagents.portfolio.risk_evaluator import check_constraints
|
from tradingagents.portfolio.risk_evaluator import check_constraints
|
||||||
|
|
@ -78,6 +77,7 @@ class TradeExecutor:
|
||||||
buys = decisions.get("buys") or []
|
buys = decisions.get("buys") or []
|
||||||
|
|
||||||
# --- SELLs first (frees cash before BUYs; no constraint pre-flight for sells) ---
|
# --- SELLs first (frees cash before BUYs; no constraint pre-flight for sells) ---
|
||||||
|
sells_to_process = []
|
||||||
for sell in sells:
|
for sell in sells:
|
||||||
ticker = (sell.get("ticker") or "").upper()
|
ticker = (sell.get("ticker") or "").upper()
|
||||||
shares = float(sell.get("shares") or 0)
|
shares = float(sell.get("shares") or 0)
|
||||||
|
|
@ -102,24 +102,34 @@ class TradeExecutor:
|
||||||
logger.warning("execute_decisions: no price for %s — skipping SELL", ticker)
|
logger.warning("execute_decisions: no price for %s — skipping SELL", ticker)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
sells_to_process.append({
|
||||||
self.repo.remove_holding(portfolio_id, ticker, shares, price)
|
|
||||||
executed_trades.append({
|
|
||||||
"action": "SELL",
|
|
||||||
"ticker": ticker,
|
"ticker": ticker,
|
||||||
"shares": shares,
|
"shares": shares,
|
||||||
"price": price,
|
"price": price,
|
||||||
"rationale": rationale,
|
"rationale": rationale,
|
||||||
"trade_date": trade_date,
|
|
||||||
})
|
})
|
||||||
logger.info("SELL %s x %.2f @ %.2f", ticker, shares, price)
|
|
||||||
except (InsufficientSharesError, PortfolioError) as exc:
|
if sells_to_process:
|
||||||
|
try:
|
||||||
|
executed, failed = self.repo.batch_remove_holdings(portfolio_id, sells_to_process, trade_date)
|
||||||
|
executed_trades.extend(executed)
|
||||||
|
for f in failed:
|
||||||
failed_trades.append({
|
failed_trades.append({
|
||||||
"action": "SELL",
|
"action": "SELL",
|
||||||
"ticker": ticker,
|
"ticker": f.get("ticker", "UNKNOWN"),
|
||||||
|
"reason": f.get("reason", "Batch execution failed"),
|
||||||
|
})
|
||||||
|
logger.warning("SELL failed for %s: %s", f.get("ticker"), f.get("reason"))
|
||||||
|
for e in executed:
|
||||||
|
logger.info("SELL %s x %.2f @ %.2f", e["ticker"], e["shares"], e["price"])
|
||||||
|
except PortfolioError as exc:
|
||||||
|
logger.error("Batch sell execution failed: %s", exc)
|
||||||
|
for s in sells_to_process:
|
||||||
|
failed_trades.append({
|
||||||
|
"action": "SELL",
|
||||||
|
"ticker": s["ticker"],
|
||||||
"reason": str(exc),
|
"reason": str(exc),
|
||||||
})
|
})
|
||||||
logger.warning("SELL failed for %s: %s", ticker, exc)
|
|
||||||
|
|
||||||
# --- BUYs second ---
|
# --- BUYs second ---
|
||||||
for buy in buys:
|
for buy in buys:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue