TradingAgents/tradingagents/memory/trade_history.py

821 lines
27 KiB
Python

"""Trade History Memory for learning from past trade outcomes.
This module provides specialized memory for tracking and learning from trades:
- Trade outcomes (profit/loss, returns)
- Agent reasoning and signals
- Entry/exit conditions
- Market context at time of trade
Issue #19: [MEM-18] Trade history memory - outcomes, agent reasoning
"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Any
import uuid
from .layered_memory import (
LayeredMemory,
MemoryEntry,
MemoryConfig,
ScoringWeights,
ImportanceLevel,
)
class TradeOutcome(Enum):
    """Labels describing how a trade turned out.

    The member values are the lowercase strings that ``TradeRecord.to_dict``
    writes and ``TradeRecord.from_dict`` parses back into members.
    """

    # The trade ended with a positive return.
    PROFITABLE = "profitable"
    # The return was roughly 0%.
    BREAK_EVEN = "break_even"
    # The trade ended with a negative return.
    LOSS = "loss"
    # The position hit its stop loss.
    STOPPED_OUT = "stopped_out"
    # The position hit its profit target.
    TARGET_HIT = "target_hit"
class TradeDirection(Enum):
    """Side of the market a trade takes.

    ``TradeRecord.close`` computes returns differently for ``LONG`` and
    ``SHORT`` and assigns a zero return to ``HOLD``.
    """

    LONG = "long"
    SHORT = "short"
    HOLD = "hold"
class SignalStrength(Enum):
    """Five-step scale for the strength of a trading signal.

    Members are declared from the most bullish to the most bearish, with
    ``NEUTRAL`` in the middle.
    """

    STRONG_BUY = "strong_buy"
    BUY = "buy"
    NEUTRAL = "neutral"
    SELL = "sell"
    STRONG_SELL = "strong_sell"
@dataclass
class AgentReasoning:
    """Free-text reasoning collected from each agent in the trading workflow.

    Every field is optional and defaults to ``None`` (nothing recorded).

    Attributes:
        fundamentals: Fundamentals analyst reasoning
        technical: Technical/market analyst reasoning
        news: News analyst reasoning
        sentiment: Social media sentiment reasoning
        momentum: Momentum analyst reasoning (if enabled)
        macro: Macro analyst reasoning (if enabled)
        correlation: Correlation analyst reasoning (if enabled)
        bull_case: Bull researcher arguments
        bear_case: Bear researcher arguments
        research_conclusion: Research manager decision
        risk_assessment: Risk manager assessment
        final_signal: Final trading signal
    """

    fundamentals: Optional[str] = None
    technical: Optional[str] = None
    news: Optional[str] = None
    sentiment: Optional[str] = None
    momentum: Optional[str] = None
    macro: Optional[str] = None
    correlation: Optional[str] = None
    bull_case: Optional[str] = None
    bear_case: Optional[str] = None
    research_conclusion: Optional[str] = None
    risk_assessment: Optional[str] = None
    final_signal: Optional[str] = None

    def to_dict(self) -> Dict[str, Optional[str]]:
        """Return all twelve fields as a plain dict, in declaration order."""
        return dict(
            fundamentals=self.fundamentals,
            technical=self.technical,
            news=self.news,
            sentiment=self.sentiment,
            momentum=self.momentum,
            macro=self.macro,
            correlation=self.correlation,
            bull_case=self.bull_case,
            bear_case=self.bear_case,
            research_conclusion=self.research_conclusion,
            risk_assessment=self.risk_assessment,
            final_signal=self.final_signal,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentReasoning":
        """Build an instance from a dict; keys that are absent become ``None``."""
        keys = (
            "fundamentals",
            "technical",
            "news",
            "sentiment",
            "momentum",
            "macro",
            "correlation",
            "bull_case",
            "bear_case",
            "research_conclusion",
            "risk_assessment",
            "final_signal",
        )
        return cls(**{key: data.get(key) for key in keys})

    def summary(self) -> str:
        """Render a one-line digest of the main reasoning fields.

        Only fundamentals, technical, bull case, bear case and research
        conclusion are included, in that order. Each is cut to its first
        100 characters and always followed by ``...``. Segments are joined
        with ``" | "``.

        Returns:
            The digest, or ``"No reasoning recorded"`` when none of the five
            fields holds a non-empty string.
        """
        labelled = (
            ("Fundamentals", self.fundamentals),
            ("Technical", self.technical),
            ("Bull", self.bull_case),
            ("Bear", self.bear_case),
            ("Conclusion", self.research_conclusion),
        )
        snippets = [f"{label}: {text[:100]}..." for label, text in labelled if text]
        if not snippets:
            return "No reasoning recorded"
        return " | ".join(snippets)
@dataclass
class MarketContext:
    """Market conditions at time of trade.

    Attributes:
        vix: VIX volatility index level
        spy_return_1d: SPY 1-day return
        sector_performance: Dict of sector returns
        economic_regime: Detected economic regime
        yield_curve_state: Yield curve status
        macro_indicators: Key macro indicators
    """

    vix: Optional[float] = None
    spy_return_1d: Optional[float] = None
    sector_performance: Dict[str, float] = field(default_factory=dict)
    economic_regime: Optional[str] = None
    yield_curve_state: Optional[str] = None
    macro_indicators: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        Returns:
            A dict with one key per attribute. The two mapping attributes are
            shallow-copied so that mutating the result cannot change this
            instance.
        """
        return {
            "vix": self.vix,
            "spy_return_1d": self.spy_return_1d,
            # Copy: handing out the live dict would alias instance state.
            "sector_performance": dict(self.sector_performance or {}),
            "economic_regime": self.economic_regime,
            "yield_curve_state": self.yield_curve_state,
            "macro_indicators": dict(self.macro_indicators or {}),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MarketContext":
        """Create from dictionary.

        Args:
            data: Mapping as produced by ``to_dict``. Missing keys fall back
                to the field defaults; a missing or ``None`` mapping value
                becomes an empty dict.

        Returns:
            A new instance that owns its own copies of the mapping values.
        """
        return cls(
            vix=data.get("vix"),
            spy_return_1d=data.get("spy_return_1d"),
            # `or {}` tolerates an explicit None; dict() detaches from `data`.
            sector_performance=dict(data.get("sector_performance") or {}),
            economic_regime=data.get("economic_regime"),
            yield_curve_state=data.get("yield_curve_state"),
            macro_indicators=dict(data.get("macro_indicators") or {}),
        )

    def summary(self) -> str:
        """Generate a text summary of market context.

        Returns:
            VIX (one decimal place), regime and yield-curve state joined with
            ``" | "``, skipping values that are unset, or
            ``"No market context"`` when none of the three is available.
        """
        parts = []
        # Explicit None check so a VIX of 0.0 would still be reported.
        if self.vix is not None:
            parts.append(f"VIX: {self.vix:.1f}")
        if self.economic_regime:
            parts.append(f"Regime: {self.economic_regime}")
        if self.yield_curve_state:
            parts.append(f"Yield Curve: {self.yield_curve_state}")
        return " | ".join(parts) if parts else "No market context"
@dataclass
class TradeRecord:
    """Complete record of a trade including reasoning and outcome.

    Attributes:
        id: Unique trade ID
        symbol: Trading symbol
        direction: Trade direction (long/short/hold)
        entry_price: Entry price
        entry_time: Entry timestamp
        quantity: Number of shares/contracts
        exit_price: Exit price (if closed)
        exit_time: Exit timestamp (if closed)
        returns: Fractional return (0.05 == 5%), set by ``close``
        pnl: Profit/loss in price units, set by ``close``
        outcome: Trade outcome category
        signal_strength: Original signal strength
        confidence: Confidence score (0-1)
        reasoning: Agent reasoning captured
        market_context: Market conditions at entry
        lessons_learned: Post-trade lessons (added later)
        tags: Trade tags for filtering
    """

    id: str
    symbol: str
    direction: TradeDirection
    entry_price: float
    entry_time: datetime
    quantity: float = 1.0
    exit_price: Optional[float] = None
    exit_time: Optional[datetime] = None
    returns: Optional[float] = None
    pnl: Optional[float] = None
    outcome: Optional[TradeOutcome] = None
    signal_strength: SignalStrength = SignalStrength.NEUTRAL
    confidence: float = 0.5
    reasoning: AgentReasoning = field(default_factory=AgentReasoning)
    market_context: MarketContext = field(default_factory=MarketContext)
    lessons_learned: Optional[str] = None
    tags: List[str] = field(default_factory=list)

    @classmethod
    def create(
        cls,
        symbol: str,
        direction: TradeDirection,
        entry_price: float,
        quantity: float = 1.0,
        signal_strength: SignalStrength = SignalStrength.NEUTRAL,
        confidence: float = 0.5,
        reasoning: Optional[AgentReasoning] = None,
        market_context: Optional[MarketContext] = None,
        tags: Optional[List[str]] = None,
    ) -> "TradeRecord":
        """Create a new open trade record.

        A fresh UUID4 string is used as the ID and ``datetime.now()`` as the
        entry time.

        Args:
            symbol: Trading symbol
            direction: Trade direction
            entry_price: Entry price
            quantity: Number of shares/contracts
            signal_strength: Signal strength behind the trade
            confidence: Confidence score
            reasoning: Agent reasoning (default: empty ``AgentReasoning``)
            market_context: Market conditions (default: empty ``MarketContext``)
            tags: Trade tags; the list is copied, not stored by reference

        Returns:
            The new trade record
        """
        return cls(
            id=str(uuid.uuid4()),
            symbol=symbol,
            direction=direction,
            entry_price=entry_price,
            entry_time=datetime.now(),
            quantity=quantity,
            signal_strength=signal_strength,
            confidence=confidence,
            reasoning=reasoning or AgentReasoning(),
            market_context=market_context or MarketContext(),
            # Copy so later changes to the caller's list do not leak in.
            tags=list(tags) if tags else [],
        )

    def close(
        self,
        exit_price: float,
        exit_time: Optional[datetime] = None,
    ) -> "TradeRecord":
        """Close the trade and calculate returns.

        Sets ``exit_price``, ``exit_time``, ``returns``, ``pnl`` and
        ``outcome`` in place. Only ``PROFITABLE``, ``LOSS`` and
        ``BREAK_EVEN`` are assigned here.

        Args:
            exit_price: Exit price
            exit_time: Exit timestamp (default: now)

        Returns:
            Self with updated exit info

        Raises:
            ZeroDivisionError: If ``entry_price`` is 0 for a long or short
                trade, because returns are relative to the entry price.
        """
        self.exit_price = exit_price
        self.exit_time = exit_time or datetime.now()

        # Fractional return relative to entry; sign flips for shorts.
        if self.direction == TradeDirection.LONG:
            self.returns = (exit_price - self.entry_price) / self.entry_price
        elif self.direction == TradeDirection.SHORT:
            self.returns = (self.entry_price - exit_price) / self.entry_price
        else:
            # HOLD: no position, so no return.
            self.returns = 0.0

        self.pnl = self.returns * self.entry_price * self.quantity

        # Returns within +/-0.5% of zero count as break-even.
        if self.returns > 0.005:
            self.outcome = TradeOutcome.PROFITABLE
        elif self.returns < -0.005:
            self.outcome = TradeOutcome.LOSS
        else:
            self.outcome = TradeOutcome.BREAK_EVEN
        return self

    def is_open(self) -> bool:
        """Check if trade is still open (no exit time recorded)."""
        return self.exit_time is None

    def holding_period_days(self) -> Optional[float]:
        """Calculate holding period in days.

        Returns:
            Time between entry and exit in (fractional) days, or ``None``
            while the trade is open.
        """
        if self.exit_time is None:
            return None
        delta = self.exit_time - self.entry_time
        return delta.total_seconds() / 86400

    def to_memory_content(self) -> str:
        """Generate the text content stored in memory for this trade.

        Returns:
            ``" | "``-joined segments: direction/symbol/entry price, then
            outcome, percentage return, research conclusion and economic
            regime when each is available.
        """
        parts = [
            f"Trade {self.direction.value} {self.symbol} at ${self.entry_price:.2f}",
        ]
        if self.outcome:
            parts.append(f"Outcome: {self.outcome.value}")
        if self.returns is not None:
            parts.append(f"Return: {self.returns * 100:.2f}%")
        if self.reasoning.research_conclusion:
            parts.append(f"Reasoning: {self.reasoning.research_conclusion}")
        if self.market_context.economic_regime:
            parts.append(f"Regime: {self.market_context.economic_regime}")
        return " | ".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        Enums are stored by value and datetimes as ISO-8601 strings. The
        ``tags`` list is copied so the result does not alias this record.
        """
        return {
            "id": self.id,
            "symbol": self.symbol,
            "direction": self.direction.value,
            "entry_price": self.entry_price,
            "exit_price": self.exit_price,
            "entry_time": self.entry_time.isoformat(),
            "exit_time": self.exit_time.isoformat() if self.exit_time else None,
            "quantity": self.quantity,
            "returns": self.returns,
            "pnl": self.pnl,
            "outcome": self.outcome.value if self.outcome else None,
            "signal_strength": self.signal_strength.value,
            "confidence": self.confidence,
            "reasoning": self.reasoning.to_dict(),
            "market_context": self.market_context.to_dict(),
            "lessons_learned": self.lessons_learned,
            "tags": list(self.tags),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TradeRecord":
        """Create from dictionary.

        Args:
            data: Mapping as produced by ``to_dict``. ``id``, ``symbol``,
                ``direction``, ``entry_price`` and ``entry_time`` are
                required; everything else falls back to the field defaults.
                ``None`` for ``reasoning``, ``market_context``, ``tags`` or
                ``signal_strength`` is treated like a missing key.

        Returns:
            The restored trade record

        Raises:
            KeyError: If a required key is missing.
            ValueError: If an enum value or timestamp cannot be parsed.
        """
        exit_time_raw = data.get("exit_time")
        outcome_raw = data.get("outcome")
        return cls(
            id=data["id"],
            symbol=data["symbol"],
            direction=TradeDirection(data["direction"]),
            entry_price=data["entry_price"],
            exit_price=data.get("exit_price"),
            entry_time=datetime.fromisoformat(data["entry_time"]),
            exit_time=datetime.fromisoformat(exit_time_raw) if exit_time_raw else None,
            quantity=data.get("quantity", 1.0),
            returns=data.get("returns"),
            pnl=data.get("pnl"),
            outcome=TradeOutcome(outcome_raw) if outcome_raw else None,
            signal_strength=SignalStrength(data.get("signal_strength") or "neutral"),
            confidence=data.get("confidence", 0.5),
            # `or {}`: an explicit None would otherwise crash the nested
            # from_dict, which calls .get on its argument.
            reasoning=AgentReasoning.from_dict(data.get("reasoning") or {}),
            market_context=MarketContext.from_dict(data.get("market_context") or {}),
            lessons_learned=data.get("lessons_learned"),
            # Copy so the record does not share a list with `data`.
            tags=list(data.get("tags") or []),
        )
class TradeHistoryMemory:
    """Memory system specialized for trade history and learning.

    Trade records are kept in a dict keyed by trade ID. Each recorded trade
    is also added to a ``LayeredMemory`` as a ``MemoryEntry`` whose ID is set
    to the trade ID, so retrieval results can be mapped back to the stored
    ``TradeRecord``.

    Example:
        >>> memory = TradeHistoryMemory()
        >>>
        >>> # Record a trade
        >>> trade = TradeRecord.create(
        ...     symbol="AAPL",
        ...     direction=TradeDirection.LONG,
        ...     entry_price=150.0,
        ...     reasoning=AgentReasoning(
        ...         fundamentals="Strong earnings",
        ...         research_conclusion="Buy on earnings momentum",
        ...     ),
        ... )
        >>> memory.record_trade(trade)
        >>>
        >>> # Close the trade
        >>> memory.close_trade(trade.id, exit_price=165.0)
        >>>
        >>> # Find similar past trades
        >>> similar = memory.find_similar_trades(
        ...     query="Apple earnings momentum",
        ...     top_k=5,
        ... )
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_function=None,
    ):
        """Initialize trade history memory.

        Args:
            config: Memory configuration. When omitted, a config weighted
                0.25 recency / 0.45 relevancy / 0.30 importance is used.
            embedding_function: Optional embedding function for similarity,
                passed through to ``LayeredMemory``
        """
        if config is None:
            config = MemoryConfig(
                weights=ScoringWeights(
                    recency=0.25,  # Recent trades somewhat important
                    relevancy=0.45,  # Similarity most important
                    importance=0.30,  # Outcome importance matters
                ),
            )
        self._layered_memory = LayeredMemory(
            config=config,
            embedding_function=embedding_function,
        )
        self._trades: Dict[str, TradeRecord] = {}

    def record_trade(self, trade: TradeRecord) -> str:
        """Record a new trade and index it in the layered memory.

        Args:
            trade: The trade record to store

        Returns:
            Trade ID
        """
        self._trades[trade.id] = trade

        importance = self._calculate_trade_importance(trade)
        entry = MemoryEntry.create(
            content=trade.to_memory_content(),
            metadata={
                "trade_id": trade.id,
                "symbol": trade.symbol,
                "direction": trade.direction.value,
                "outcome": trade.outcome.value if trade.outcome else None,
                "returns": trade.returns,
                "reasoning_summary": trade.reasoning.summary(),
            },
            importance=importance,
            # Symbol and direction double as tags so retrieval can filter on them.
            tags=trade.tags + [trade.symbol, trade.direction.value],
            timestamp=trade.entry_time,
        )
        # Reuse the trade ID so close_trade can find this entry again.
        entry.id = trade.id
        self._layered_memory.add(entry)
        return trade.id

    def close_trade(
        self,
        trade_id: str,
        exit_price: float,
        lessons_learned: Optional[str] = None,
        exit_time: Optional[datetime] = None,
    ) -> Optional[TradeRecord]:
        """Close an open trade and update memory.

        Args:
            trade_id: ID of the trade to close
            exit_price: Exit price
            lessons_learned: Optional lessons from the trade
            exit_time: Optional exit timestamp, forwarded to
                ``TradeRecord.close`` (which defaults to now)

        Returns:
            Updated trade record or None if not found
        """
        trade = self._trades.get(trade_id)
        if trade is None:
            return None

        trade.close(exit_price, exit_time)
        if lessons_learned:
            trade.lessons_learned = lessons_learned

        # Importance is now driven by the realized return.
        importance = self._calculate_trade_importance(trade)
        self._layered_memory.update_importance(trade_id, importance)

        # Refresh the entry's metadata; its text content is left as recorded.
        entry = self._layered_memory.get(trade_id)
        if entry:
            entry.metadata["outcome"] = trade.outcome.value if trade.outcome else None
            entry.metadata["returns"] = trade.returns
            if lessons_learned:
                entry.metadata["lessons_learned"] = lessons_learned
        return trade

    def get_trade(self, trade_id: str) -> Optional[TradeRecord]:
        """Get a trade by ID.

        Args:
            trade_id: Trade ID

        Returns:
            Trade record or None
        """
        return self._trades.get(trade_id)

    def get_open_trades(self) -> List[TradeRecord]:
        """Get all open trades.

        Returns:
            List of open trade records
        """
        return [t for t in self._trades.values() if t.is_open()]

    def get_closed_trades(self) -> List[TradeRecord]:
        """Get all closed trades.

        Returns:
            List of closed trade records
        """
        return [t for t in self._trades.values() if not t.is_open()]

    def get_trades_by_symbol(self, symbol: str) -> List[TradeRecord]:
        """Get all trades for a symbol (exact, case-sensitive match).

        Args:
            symbol: Trading symbol

        Returns:
            List of trade records for the symbol
        """
        return [t for t in self._trades.values() if t.symbol == symbol]

    def find_similar_trades(
        self,
        query: str,
        top_k: int = 5,
        symbol: Optional[str] = None,
        outcome: Optional[TradeOutcome] = None,
    ) -> List[TradeRecord]:
        """Find similar past trades.

        Args:
            query: Query describing the current situation
            top_k: Maximum number of results
            symbol: Optional filter by symbol
            outcome: Optional filter by outcome

        Returns:
            List of similar trade records, in retrieval order
        """
        tags = [symbol] if symbol else []

        # Over-fetch because the outcome filter below may drop candidates.
        results = self._layered_memory.retrieve(
            query=query,
            top_k=top_k * 2,
            tags=tags if tags else None,
        )

        trades = []
        for scored in results:
            trade_id = scored.entry.metadata.get("trade_id")
            if trade_id and trade_id in self._trades:
                trade = self._trades[trade_id]
                if outcome and trade.outcome != outcome:
                    continue
                trades.append(trade)
                if len(trades) >= top_k:
                    break
        return trades

    def _find_by_return(
        self,
        query: str,
        top_k: int,
        min_return: Optional[float] = None,
        max_return: Optional[float] = None,
    ) -> List[TradeRecord]:
        """Retrieve trades for *query* whose return lies within the bounds.

        Trades without a computed return are skipped. A bound of ``None``
        means unbounded on that side; bounds are inclusive.
        """
        # Over-fetch because the return filter may drop candidates.
        results = self._layered_memory.retrieve(query=query, top_k=top_k * 3)
        trades = []
        for scored in results:
            trade_id = scored.entry.metadata.get("trade_id")
            if trade_id and trade_id in self._trades:
                trade = self._trades[trade_id]
                if trade.returns is None:
                    continue
                above_floor = min_return is None or trade.returns >= min_return
                below_cap = max_return is None or trade.returns <= max_return
                if above_floor and below_cap:
                    trades.append(trade)
                    if len(trades) >= top_k:
                        break
        return trades

    def find_profitable_patterns(
        self,
        query: str,
        min_return: float = 0.0,
        top_k: int = 5,
    ) -> List[TradeRecord]:
        """Find profitable trades similar to the query.

        Args:
            query: Query describing the current situation
            min_return: Minimum return filter (inclusive)
            top_k: Maximum number of results

        Returns:
            List of trade records with ``returns >= min_return``
        """
        return self._find_by_return(query, top_k, min_return=min_return)

    def find_losing_patterns(
        self,
        query: str,
        max_return: float = 0.0,
        top_k: int = 5,
    ) -> List[TradeRecord]:
        """Find losing trades similar to the query (for learning what to avoid).

        Args:
            query: Query describing the current situation
            max_return: Maximum return filter (inclusive)
            top_k: Maximum number of results

        Returns:
            List of trade records with ``returns <= max_return``
        """
        return self._find_by_return(query, top_k, max_return=max_return)

    def get_statistics(self) -> Dict[str, Any]:
        """Get trade history statistics.

        Returns:
            Dictionary with total/open/closed counts, win rate (share of
            closed trades whose outcome is ``PROFITABLE``), average return,
            total PnL, best and worst trade (as dicts) and the outcome
            distribution. The same keys are present when no trade has been
            closed yet.
        """
        closed_trades = self.get_closed_trades()
        if not closed_trades:
            return {
                "total_trades": len(self._trades),
                "open_trades": len(self.get_open_trades()),
                "closed_trades": 0,
                "win_rate": 0.0,
                "avg_return": 0.0,
                "total_pnl": 0.0,
                "best_trade": None,
                "worst_trade": None,
                # Keep the schema identical to the populated case.
                "outcome_distribution": {},
            }

        returns = [t.returns for t in closed_trades if t.returns is not None]
        pnls = [t.pnl for t in closed_trades if t.pnl is not None]
        winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]

        # A missing return is ranked as 0 when picking best/worst.
        best_trade = max(closed_trades, key=lambda t: t.returns or 0, default=None)
        worst_trade = min(closed_trades, key=lambda t: t.returns or 0, default=None)

        outcome_dist: Dict[str, int] = {}
        for trade in closed_trades:
            if trade.outcome:
                outcome_dist[trade.outcome.value] = outcome_dist.get(trade.outcome.value, 0) + 1

        return {
            "total_trades": len(self._trades),
            "open_trades": len(self.get_open_trades()),
            "closed_trades": len(closed_trades),
            "win_rate": len(winners) / len(closed_trades) if closed_trades else 0.0,
            "avg_return": sum(returns) / len(returns) if returns else 0.0,
            "total_pnl": sum(pnls) if pnls else 0.0,
            "best_trade": best_trade.to_dict() if best_trade else None,
            "worst_trade": worst_trade.to_dict() if worst_trade else None,
            "outcome_distribution": outcome_dist,
        }

    def get_symbol_statistics(self, symbol: str) -> Dict[str, Any]:
        """Get statistics for a specific symbol.

        Args:
            symbol: Trading symbol

        Returns:
            Dictionary with the symbol, total/open/closed counts, win rate
            and average return for that symbol's trades
        """
        symbol_trades = self.get_trades_by_symbol(symbol)
        closed_trades = [t for t in symbol_trades if not t.is_open()]
        open_count = len(symbol_trades) - len(closed_trades)

        if not closed_trades:
            return {
                "symbol": symbol,
                "total_trades": len(symbol_trades),
                "open_trades": open_count,
                "closed_trades": 0,
                "win_rate": 0.0,
                "avg_return": 0.0,
            }

        returns = [t.returns for t in closed_trades if t.returns is not None]
        winners = [t for t in closed_trades if t.outcome == TradeOutcome.PROFITABLE]
        return {
            "symbol": symbol,
            "total_trades": len(symbol_trades),
            "open_trades": open_count,
            "closed_trades": len(closed_trades),
            "win_rate": len(winners) / len(closed_trades),
            "avg_return": sum(returns) / len(returns) if returns else 0.0,
        }

    def _calculate_trade_importance(self, trade: TradeRecord) -> float:
        """Calculate importance score for a trade.

        Trades with a computed return are graded by its absolute size
        (>=10% critical, >=5% high, >=1% medium, otherwise low). Trades
        without one are graded by confidence (>=0.8 high, >=0.5 medium,
        otherwise low).

        Args:
            trade: Trade record

        Returns:
            The ``.value`` of the selected ``ImportanceLevel`` member
        """
        if trade.returns is not None:
            abs_return = abs(trade.returns)
            if abs_return >= 0.10:
                return ImportanceLevel.CRITICAL.value
            if abs_return >= 0.05:
                return ImportanceLevel.HIGH.value
            if abs_return >= 0.01:
                return ImportanceLevel.MEDIUM.value
            return ImportanceLevel.LOW.value

        if trade.confidence >= 0.8:
            return ImportanceLevel.HIGH.value
        if trade.confidence >= 0.5:
            return ImportanceLevel.MEDIUM.value
        return ImportanceLevel.LOW.value

    def count(self) -> int:
        """Return total number of trades."""
        return len(self._trades)

    def clear(self) -> int:
        """Clear all trades and the layered memory.

        Returns:
            Number of trades cleared
        """
        count = len(self._trades)
        self._trades.clear()
        self._layered_memory.clear()
        return count

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to dictionary.

        Returns:
            Dictionary with the serialized trades under ``"trades"`` and the
            layered memory's own serialization under ``"memory"``
        """
        return {
            "trades": [t.to_dict() for t in self._trades.values()],
            "memory": self._layered_memory.to_dict(),
        }

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
        embedding_function=None,
    ) -> "TradeHistoryMemory":
        """Create from dictionary.

        Args:
            data: Dictionary representation, as produced by ``to_dict``
            embedding_function: Optional embedding function

        Returns:
            TradeHistoryMemory instance. When ``data`` has no ``"memory"``
            snapshot, the restored trades are re-indexed through
            ``record_trade`` so that similarity search still finds them.
        """
        instance = cls(embedding_function=embedding_function)
        trades = [TradeRecord.from_dict(item) for item in data.get("trades", [])]

        if "memory" in data:
            for trade in trades:
                instance._trades[trade.id] = trade
            instance._layered_memory = LayeredMemory.from_dict(
                data["memory"],
                embedding_function=embedding_function,
            )
        else:
            # No snapshot to restore: rebuild the index from the trades,
            # otherwise the layered memory would stay empty.
            for trade in trades:
                instance.record_trade(trade)
        return instance