diff --git a/tests/portfolio/test_trade_executor.py b/tests/portfolio/test_trade_executor.py index 2c6ff455..9786f3a0 100644 --- a/tests/portfolio/test_trade_executor.py +++ b/tests/portfolio/test_trade_executor.py @@ -89,8 +89,13 @@ PRICES = {"AAPL": 150.0, "MSFT": 300.0} def test_execute_sell_success(): - """Successful SELL calls remove_holding and is in executed_trades.""" + """Successful SELL calls batch_remove_holdings and is in executed_trades.""" repo = _make_repo() + # Mock batch_remove_holdings to return a tuple of (executed, failed) + repo.batch_remove_holdings.return_value = ( + [{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}], + [] + ) executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG) decisions = { @@ -99,7 +104,11 @@ def test_execute_sell_success(): } result = executor.execute_decisions("p1", decisions, PRICES) - repo.remove_holding.assert_called_once_with("p1", "AAPL", 5.0, 150.0) + repo.batch_remove_holdings.assert_called_once() + args, kwargs = repo.batch_remove_holdings.call_args + assert args[0] == "p1" + assert args[1] == [{"ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Stop loss"}] + assert len(result["executed_trades"]) == 1 assert result["executed_trades"][0]["action"] == "SELL" assert result["executed_trades"][0]["ticker"] == "AAPL" @@ -123,9 +132,12 @@ def test_execute_sell_missing_price(): def test_execute_sell_insufficient_shares(): - """SELL that raises InsufficientSharesError → failed_trade.""" + """SELL that fails due to logic in batch_remove_holdings → failed_trade.""" repo = _make_repo() - repo.remove_holding.side_effect = InsufficientSharesError("Not enough shares") + repo.batch_remove_holdings.return_value = ( + [], + [{"action": "SELL", "ticker": "AAPL", "reason": "Hold 10.0 shares of AAPL, cannot sell 999.0"}] + ) executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG) decisions = { @@ -135,7 +147,7 @@ def test_execute_sell_insufficient_shares(): result = executor.execute_decisions("p1", decisions, PRICES) assert len(result["failed_trades"]) == 1 - assert "Not enough shares" in result["failed_trades"][0]["reason"] + assert "cannot sell 999.0" in result["failed_trades"][0]["reason"] # --------------------------------------------------------------------------- @@ -207,6 +219,10 @@ def test_execute_decisions_sells_before_buys(): """SELLs are always executed before BUYs.""" portfolio = _make_portfolio(cash=50_000.0, total_value=60_000.0) repo = _make_repo(portfolio=portfolio) + repo.batch_remove_holdings.return_value = ( + [{"action": "SELL", "ticker": "AAPL", "shares": 5.0, "price": 150.0, "rationale": "Exit"}], + [] + ) executor = TradeExecutor(repo=repo, config=_DEFAULT_CONFIG) decisions = { @@ -215,9 +231,9 @@ def test_execute_decisions_sells_before_buys(): } executor.execute_decisions("p1", decisions, PRICES) - # Verify call order: remove_holding before add_holding - call_order = [c[0] for c in repo.method_calls if c[0] in ("remove_holding", "add_holding")] - assert call_order.index("remove_holding") < call_order.index("add_holding") + # Verify call order: batch_remove_holdings before add_holding + call_order = [c[0] for c in repo.method_calls if c[0] in ("batch_remove_holdings", "add_holding")] + assert call_order.index("batch_remove_holdings") < call_order.index("add_holding") def test_execute_decisions_takes_snapshot(): diff --git a/tradingagents/portfolio/repository.py b/tradingagents/portfolio/repository.py index 0a556ef5..d39c1144 100644 --- a/tradingagents/portfolio/repository.py +++ b/tradingagents/portfolio/repository.py @@ -233,6 +233,118 @@ class PortfolioRepository: return result + + def batch_remove_holdings( + self, + portfolio_id: str, + sells: list[dict[str, Any]], + trade_date: str, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """Sell shares in batch and update portfolio cash and holdings. + + Args: + portfolio_id: Portfolio ID. + sells: List of dicts with keys 'ticker', 'shares', 'price', 'rationale'. + trade_date: The date to record the trades. + + Returns: + Tuple of (executed_trades, failed_trades). + """ + executed_trades = [] + failed_trades = [] + + if not sells: + return executed_trades, failed_trades + + # Pre-fetch portfolio and holdings once + portfolio = self._client.get_portfolio(portfolio_id) + current_holdings = {h.ticker.upper(): h for h in self._client.list_holdings(portfolio_id)} + + holdings_to_upsert = {} + tickers_to_delete = set() + trades_to_record = [] + total_proceeds = 0.0 + + for sell in sells: + ticker = sell["ticker"] + shares = sell["shares"] + price = sell["price"] + rationale = sell.get("rationale") + + existing = current_holdings.get(ticker.upper()) + if not existing: + failed_trades.append({ + "action": "SELL", + "ticker": ticker, + "reason": f"No holding for {ticker} in portfolio {portfolio_id}", + }) + continue + + if existing.shares < shares: + failed_trades.append({ + "action": "SELL", + "ticker": ticker, + "reason": f"Hold {existing.shares} shares of {ticker}, cannot sell {shares}", + }) + continue + + proceeds = shares * price + total_proceeds += proceeds + + if existing.shares == shares: + tickers_to_delete.add(ticker.upper()) + # If we previously marked it to upsert, remove it + if ticker.upper() in holdings_to_upsert: + del holdings_to_upsert[ticker.upper()] + # Remove from local tracking + del current_holdings[ticker.upper()] + else: + existing.shares -= shares + holdings_to_upsert[ticker.upper()] = existing + + trade = Trade( + trade_id=str(uuid.uuid4()), + portfolio_id=portfolio_id, + ticker=ticker.upper(), + action="SELL", + shares=shares, + price=price, + total_value=proceeds, + trade_date=trade_date, + rationale=rationale, + signal_source="pm_agent", + ) + trades_to_record.append(trade) + + executed_trades.append({ + "action": "SELL", + "ticker": ticker, + "shares": shares, + "price": price, + "rationale": rationale, + "trade_date": trade_date, + }) + + if not executed_trades: + return executed_trades, failed_trades + + try: + # Apply database writes in batch + if tickers_to_delete: + self._client.batch_delete_holdings(portfolio_id, list(tickers_to_delete)) + if holdings_to_upsert: + self._client.batch_upsert_holdings(list(holdings_to_upsert.values())) + + portfolio.cash += total_proceeds + self._client.update_portfolio(portfolio) + + if trades_to_record: + self._client.batch_record_trades(trades_to_record) + except Exception as exc: + raise PortfolioError(f"Batch write failed: {exc}") from exc + + return executed_trades, failed_trades + # ------------------------------------------------------------------ # Snapshots # ------------------------------------------------------------------ diff --git a/tradingagents/portfolio/supabase_client.py b/tradingagents/portfolio/supabase_client.py index c0107544..7c3aa0bc 100644 --- a/tradingagents/portfolio/supabase_client.py +++ b/tradingagents/portfolio/supabase_client.py @@ -222,6 +222,35 @@ class SupabaseClient: f"Holding not found: {ticker} in portfolio {portfolio_id}" ) + + def batch_upsert_holdings(self, holdings: list[Holding]) -> None: + if not holdings: + return + query = ''' + INSERT INTO holdings + (holding_id, portfolio_id, ticker, shares, avg_cost, sector, industry) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT ON CONSTRAINT holdings_portfolio_ticker_unique + DO UPDATE SET shares = EXCLUDED.shares, + avg_cost = EXCLUDED.avg_cost, + sector = EXCLUDED.sector, + industry = EXCLUDED.industry + ''' + params = [ + (h.holding_id or str(uuid.uuid4()), h.portfolio_id, h.ticker.upper(), h.shares, h.avg_cost, h.sector, h.industry) + for h in holdings + ] + with self._cursor() as cur: + psycopg2.extras.execute_batch(cur, query, params) + + def batch_delete_holdings(self, portfolio_id: str, tickers: list[str]) -> None: + if not tickers: + return + query = "DELETE FROM holdings WHERE portfolio_id = %s AND ticker = %s" + params = [(portfolio_id, ticker.upper()) for ticker in tickers] + with self._cursor() as cur: + psycopg2.extras.execute_batch(cur, query, params) + # ------------------------------------------------------------------ # Trades # ------------------------------------------------------------------ @@ -266,6 +295,25 @@ class SupabaseClient: rows = cur.fetchall() return [self._row_to_trade(r) for r in rows] + + def batch_record_trades(self, trades: list[Trade]) -> None: + if not trades: + return + query = ''' + INSERT INTO trades + (trade_id, portfolio_id, ticker, action, shares, price, + total_value, rationale, signal_source, metadata) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + ''' + params = [ + (t.trade_id or str(uuid.uuid4()), t.portfolio_id, t.ticker, t.action, + t.shares, t.price, t.total_value, t.rationale, t.signal_source, + json.dumps(t.metadata)) + for t in trades + ] + with self._cursor() as cur: + psycopg2.extras.execute_batch(cur, query, params) + # ------------------------------------------------------------------ # Snapshots # ------------------------------------------------------------------ diff --git a/tradingagents/portfolio/trade_executor.py b/tradingagents/portfolio/trade_executor.py index a11bd573..bd14bc98 100644 --- a/tradingagents/portfolio/trade_executor.py +++ b/tradingagents/portfolio/trade_executor.py @@ -13,7 +13,6 @@ from typing import Any from tradingagents.portfolio.exceptions import ( InsufficientCashError, - InsufficientSharesError, PortfolioError, ) from tradingagents.portfolio.risk_evaluator import check_constraints @@ -78,6 +77,7 @@ class TradeExecutor: buys = decisions.get("buys") or [] # --- SELLs first (frees cash before BUYs; no constraint pre-flight for sells) --- + sells_to_process = [] for sell in sells: ticker = (sell.get("ticker") or "").upper() shares = float(sell.get("shares") or 0) @@ -102,24 +102,34 @@ class TradeExecutor: logger.warning("execute_decisions: no price for %s — skipping SELL", ticker) continue + sells_to_process.append({ + "ticker": ticker, + "shares": shares, + "price": price, + "rationale": rationale, + }) + + if sells_to_process: try: - self.repo.remove_holding(portfolio_id, ticker, shares, price) - executed_trades.append({ - "action": "SELL", - "ticker": ticker, - "shares": shares, - "price": price, - "rationale": rationale, - "trade_date": trade_date, - }) - logger.info("SELL %s x %.2f @ %.2f", ticker, shares, price) - except (InsufficientSharesError, PortfolioError) as exc: - failed_trades.append({ - "action": "SELL", - "ticker": ticker, - "reason": str(exc), - }) - logger.warning("SELL failed for %s: %s", ticker, exc) + executed, failed = self.repo.batch_remove_holdings(portfolio_id, sells_to_process, trade_date) + executed_trades.extend(executed) + for f in failed: + failed_trades.append({ + "action": "SELL", + "ticker": f.get("ticker", "UNKNOWN"), + "reason": f.get("reason", "Batch execution failed"), + }) + logger.warning("SELL failed for %s: %s", f.get("ticker"), f.get("reason")) + for e in executed: + logger.info("SELL %s x %.2f @ %.2f", e["ticker"], e["shares"], e["price"]) + except PortfolioError as exc: + logger.error("Batch sell execution failed: %s", exc) + for s in sells_to_process: + failed_trades.append({ + "action": "SELL", + "ticker": s["ticker"], + "reason": str(exc), + }) # --- BUYs second --- for buy in buys: