"""Portfolio state models for swing trading."""

from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional


@dataclass
class Position:
    """An open position in the portfolio."""

    ticker: str
    market: str  # "KRX" or "US"
    entry_date: str  # parsed as "%Y-%m-%d"
    entry_price: float
    quantity: int
    stop_loss: float
    take_profit: float
    max_hold_days: int = 20
    current_price: float = 0.0
    screening_reason: str = ""

    def days_held_as_of(self, as_of: Optional[str] = None) -> int:
        """Return the number of calendar days between ``entry_date`` and ``as_of``.

        Args:
            as_of: Reference date whose first ten characters are parsed as
                ``YYYY-MM-DD``, so ISO timestamps such as
                ``"2024-01-05T09:30:00"`` are accepted. When falsy, the
                current local time is used.

        Returns:
            Whole days elapsed. This is negative if ``as_of`` precedes the entry.

        Raises:
            ValueError: If ``entry_date`` or ``as_of`` is not in the expected format.
        """
        entry = datetime.strptime(self.entry_date, "%Y-%m-%d")
        if as_of:
            reference = datetime.strptime(as_of[:10], "%Y-%m-%d")
        else:
            reference = datetime.now()
        return (reference - entry).days

    @property
    def days_held(self) -> int:
        """Return the number of days held as of now."""
        return self.days_held_as_of()

    @property
    def unrealized_pnl(self) -> float:
        """Return the paper profit/loss at ``current_price``."""
        return (self.current_price - self.entry_price) * self.quantity

    @property
    def unrealized_pnl_pct(self) -> float:
        """Return the paper return in percent, or 0.0 when ``entry_price`` is 0."""
        if self.entry_price == 0:
            return 0.0
        return (self.current_price - self.entry_price) / self.entry_price * 100

    @property
    def cost_basis(self) -> float:
        """Return the capital committed to the position (entry price x quantity)."""
        return self.entry_price * self.quantity

    def should_check_exit(self, current_date: str) -> bool:
        """Check if position needs exit evaluation (stop-loss, take-profit, or max hold).

        Args:
            current_date: Date the evaluation is made for (``YYYY-MM-DD`` or an
                ISO timestamp). The max-hold rule counts days up to this date,
                so backtests and replays are evaluated correctly. When falsy,
                today is used.

        Returns:
            True if ``current_price`` is at or beyond either price level, or if
            the position has been held for at least ``max_hold_days``.
        """
        if self.current_price <= self.stop_loss:
            return True
        if self.current_price >= self.take_profit:
            return True
        # Previously this used ``days_held`` (always "now") and ignored
        # ``current_date``.
        return self.days_held_as_of(current_date) >= self.max_hold_days


@dataclass
class ClosedTrade:
    """A completed trade with realized P&L."""

    ticker: str
    market: str
    entry_date: str
    exit_date: str
    entry_price: float
    exit_price: float
    quantity: int
    exit_reason: str  # "stop_loss", "take_profit", "max_hold", "agent_decision"

    @property
    def pnl(self) -> float:
        """Return the realized profit/loss."""
        return (self.exit_price - self.entry_price) * self.quantity

    @property
    def pnl_pct(self) -> float:
        """Return the realized return in percent, or 0.0 when ``entry_price`` is 0."""
        if self.entry_price == 0:
            return 0.0
        return (self.exit_price - self.entry_price) / self.entry_price * 100


@dataclass
class Order:
    """A trading order generated by the system."""

    action: str  # "BUY", "SELL"
    ticker: str
    market: str
    price: float
    stop_loss: float
    take_profit: float
    quantity: int
    position_size_pct: float  # % of total capital
    max_hold_days: int = 20
    rationale: str = ""
    # Creation time as a local (naive) ISO-8601 string.
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


@dataclass
class PortfolioState:
    """Complete portfolio state for swing trading."""

    portfolio_id: str = "default"
    total_capital: float = 100_000_000  # 100 million (1억원) default
    available_capital: float = 100_000_000  # cash not currently tied up in positions
    max_positions: int = 5
    max_position_pct: float = 0.20  # 20% of total capital per position
    positions: dict[str, Position] = field(default_factory=dict)  # keyed by ticker
    closed_trades: list[ClosedTrade] = field(default_factory=list)
    orders_history: list[Order] = field(default_factory=list)
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    @property
    def invested_capital(self) -> float:
        """Return the total cost basis of all open positions."""
        return sum(p.cost_basis for p in self.positions.values())

    @property
    def total_unrealized_pnl(self) -> float:
        """Return the sum of unrealized P&L across open positions."""
        return sum(p.unrealized_pnl for p in self.positions.values())

    @property
    def total_realized_pnl(self) -> float:
        """Return the sum of realized P&L across closed trades."""
        return sum(t.pnl for t in self.closed_trades)

    @property
    def position_count(self) -> int:
        """Return the number of open positions."""
        return len(self.positions)

    def can_add_position(self) -> bool:
        """Return True while fewer than ``max_positions`` positions are open."""
        return self.position_count < self.max_positions

    def available_slots(self) -> int:
        """Return how many more positions may be opened."""
        return self.max_positions - self.position_count

    def max_position_capital(self) -> float:
        """Return the capital cap for a single position."""
        return self.total_capital * self.max_position_pct

    def has_position(self, ticker: str) -> bool:
        """Return True if ``ticker`` has an open position."""
        return ticker in self.positions

    def add_position(self, order: Order) -> None:
        """Add a BUY order into portfolio, supporting DCA averaging for existing positions.

        For an existing ticker, the entry price, stop-loss and take-profit become
        quantity-weighted averages. ``max_hold_days`` takes the larger value, and
        the original ``entry_date`` is kept. Otherwise a new ``Position`` dated
        today is created.

        The order's cost is always deducted from ``available_capital`` and the
        order is appended to ``orders_history``. This method does not check
        ``order.action``, position limits, or whether enough capital is
        available; callers must do that.
        """
        added_cost = order.price * order.quantity
        if order.ticker in self.positions:
            # DCA: update weighted-average entry/levels for existing position.
            pos = self.positions[order.ticker]
            old_qty = pos.quantity
            new_qty = old_qty + order.quantity
            if new_qty > 0:
                pos.entry_price = (
                    (pos.entry_price * old_qty) + (order.price * order.quantity)
                ) / new_qty
                pos.stop_loss = (
                    (pos.stop_loss * old_qty) + (order.stop_loss * order.quantity)
                ) / new_qty
                pos.take_profit = (
                    (pos.take_profit * old_qty) + (order.take_profit * order.quantity)
                ) / new_qty
                pos.quantity = new_qty
            pos.max_hold_days = max(pos.max_hold_days, order.max_hold_days)
            pos.current_price = order.price
            if order.rationale:
                if pos.screening_reason:
                    pos.screening_reason = f"{pos.screening_reason} | {order.rationale}"
                else:
                    pos.screening_reason = order.rationale
        else:
            self.positions[order.ticker] = Position(
                ticker=order.ticker,
                market=order.market,
                entry_date=datetime.now().strftime("%Y-%m-%d"),
                entry_price=order.price,
                quantity=order.quantity,
                stop_loss=order.stop_loss,
                take_profit=order.take_profit,
                max_hold_days=order.max_hold_days,
                current_price=order.price,
                screening_reason=order.rationale,
            )
        self.available_capital -= added_cost
        self.orders_history.append(order)
        self.updated_at = datetime.now().isoformat()

    def close_position(self, ticker: str, exit_price: float, exit_reason: str) -> Optional[ClosedTrade]:
        """Close an existing position and record the trade.

        Args:
            ticker: Ticker of the position to close.
            exit_price: Price at which the whole position is sold.
            exit_reason: Label stored on the ``ClosedTrade``.

        Returns:
            The recorded ``ClosedTrade``, or None if ``ticker`` has no open
            position. On success the exit proceeds are credited to
            ``available_capital``.
        """
        if ticker not in self.positions:
            return None
        pos = self.positions.pop(ticker)
        trade = ClosedTrade(
            ticker=pos.ticker,
            market=pos.market,
            entry_date=pos.entry_date,
            exit_date=datetime.now().strftime("%Y-%m-%d"),
            entry_price=pos.entry_price,
            exit_price=exit_price,
            quantity=pos.quantity,
            exit_reason=exit_reason,
        )
        self.closed_trades.append(trade)
        self.available_capital += exit_price * pos.quantity
        self.updated_at = datetime.now().isoformat()
        return trade

    def summary(self) -> str:
        """Generate a text summary of the portfolio for agent context.

        The labels are in Korean. The first section holds capital and P&L
        totals, followed by one line per open position if there are any.
        """
        lines = [
            "=== 포트폴리오 현황 ===",
            f"총 자본: {self.total_capital:,.0f}",
            f"가용 자본: {self.available_capital:,.0f}",
            f"투자 중: {self.invested_capital:,.0f}",
            f"포지션: {self.position_count}/{self.max_positions}",
            f"미실현 손익: {self.total_unrealized_pnl:,.0f}",
            f"실현 손익: {self.total_realized_pnl:,.0f}",
        ]
        if self.positions:
            lines.append("\n--- 보유 종목 ---")
            for ticker, pos in self.positions.items():
                lines.append(
                    f"  {ticker}: 수량 {pos.quantity:,} / "
                    f"평단 {pos.entry_price:,.0f} / "
                    f"현재가 {pos.current_price:,.0f} / "
                    f"수익률 {pos.unrealized_pnl_pct:+.1f}% / "
                    f"보유일 {pos.days_held}일 / "
                    f"손절 {pos.stop_loss:,.0f} / 익절 {pos.take_profit:,.0f}"
                )
        return "\n".join(lines)