# TradingAgents/tradingagents/memory/integration.py
"""Memory integration for agent prompts.
This module provides integration between the memory system and agent prompts,
enabling agents to access relevant historical context for better decision-making.
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
from .layered_memory import LayeredMemory, MemoryEntry, MemoryConfig, ScoringWeights
from .trade_history import TradeHistoryMemory, TradeRecord, TradeOutcome, TradeDirection
from .risk_profiles import (
RiskProfileMemory,
RiskProfile,
RiskDecision,
RiskTolerance,
MarketRegime,
RiskCategory,
)
class ContextType(Enum):
    """Kinds of memory context that can be included in an agent prompt.

    ``ALL`` acts as a wildcard: when present in a selection list it
    selects every other context type.
    """

    TRADE_HISTORY = "trade_history"
    RISK_PROFILE = "risk_profile"
    SIMILAR_SITUATIONS = "similar_situations"
    LESSONS_LEARNED = "lessons_learned"
    ALL = "all"
@dataclass
class MemoryContext:
    """Memory context bundle passed into agent prompts.

    Attributes:
        trade_history: Summary of relevant past trades.
        risk_context: Risk profile recommendations.
        similar_situations: Similar past situations and outcomes.
        lessons_learned: Key lessons from past trades.
        raw_trades: List of relevant TradeRecord objects.
    """

    trade_history: str = ""
    risk_context: str = ""
    similar_situations: str = ""
    lessons_learned: str = ""
    raw_trades: List[TradeRecord] = field(default_factory=list)

    def to_prompt_string(self, include_types: Optional[List[ContextType]] = None) -> str:
        """Render the selected context sections as prompt-ready text.

        Args:
            include_types: Types of context to include; ``None`` (or a list
                containing ``ContextType.ALL``) selects every section.

        Returns:
            Markdown-style sections joined by blank lines, or a fixed
            placeholder string when nothing is available.
        """
        wanted = [ContextType.ALL] if include_types is None else include_types
        wildcard = ContextType.ALL in wanted

        # Fixed rendering order: history, risk, similar situations, lessons.
        sections = (
            (ContextType.TRADE_HISTORY, "Recent Trade History", self.trade_history),
            (ContextType.RISK_PROFILE, "Risk Profile Context", self.risk_context),
            (ContextType.SIMILAR_SITUATIONS, "Similar Past Situations", self.similar_situations),
            (ContextType.LESSONS_LEARNED, "Lessons Learned", self.lessons_learned),
        )
        parts = [
            f"## {title}\n{body}"
            for kind, title, body in sections
            if body and (wildcard or kind in wanted)
        ]
        if not parts:
            return "No relevant memory context available."
        return "\n\n".join(parts)

    def is_empty(self) -> bool:
        """Return True when no textual context section is populated."""
        return not (
            self.trade_history
            or self.risk_context
            or self.similar_situations
            or self.lessons_learned
        )
class AgentMemoryIntegration:
    """Integration layer between memory systems and agent prompts.

    This class provides methods to retrieve relevant memory context
    for different agents in the trading system.

    Example:
        >>> integration = AgentMemoryIntegration()
        >>>
        >>> # Get context for an analyst
        >>> context = integration.get_analyst_context(
        ...     symbol="AAPL",
        ...     current_situation="Tech sector showing momentum",
        ...     analyst_type="momentum",
        ... )
        >>>
        >>> # Use in prompt
        >>> prompt = f"Analyze AAPL. Memory context: {context.to_prompt_string()}"
    """

    def __init__(
        self,
        trade_memory: Optional[TradeHistoryMemory] = None,
        risk_memory: Optional[RiskProfileMemory] = None,
        situation_memory: Optional[LayeredMemory] = None,
        embedding_function: Optional[Callable] = None,
    ):
        """Initialize memory integration.

        Args:
            trade_memory: Trade history memory instance (created if None).
            risk_memory: Risk profile memory instance (created if None).
            situation_memory: General situation memory (created if None).
            embedding_function: Optional embedding function for similarity.
        """
        self._trade_memory = trade_memory or TradeHistoryMemory()
        self._risk_memory = risk_memory or RiskProfileMemory()
        # Default situation memory weights relevancy highest so that
        # semantically similar situations dominate retrieval.
        self._situation_memory = situation_memory or LayeredMemory(
            config=MemoryConfig(
                weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2)
            ),
            embedding_function=embedding_function,
        )
        self._embedding_function = embedding_function

    @property
    def trade_memory(self) -> TradeHistoryMemory:
        """Access trade history memory."""
        return self._trade_memory

    @property
    def risk_memory(self) -> RiskProfileMemory:
        """Access risk profile memory."""
        return self._risk_memory

    @property
    def situation_memory(self) -> LayeredMemory:
        """Access situation memory."""
        return self._situation_memory

    def get_analyst_context(
        self,
        symbol: str,
        current_situation: str,
        analyst_type: str = "general",
        lookback_days: int = 90,
        max_trades: int = 5,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for an analyst agent.

        Args:
            symbol: Stock symbol being analyzed.
            current_situation: Current market situation description.
            analyst_type: Type of analyst (momentum, macro, etc.).
            lookback_days: Days to look back for trades.
            max_trades: Maximum trades to include.
            user_id: User ID. NOTE(review): currently unused here; kept for
                interface symmetry with the other context getters.

        Returns:
            MemoryContext with relevant information.
        """
        context = MemoryContext()
        # Relevant past trades for this symbol, restricted to the lookback
        # window and capped at max_trades.
        trades = self._trade_memory.get_trades_by_symbol(symbol)
        if trades:
            cutoff = datetime.now() - timedelta(days=lookback_days)
            recent_trades = [t for t in trades if t.entry_time >= cutoff][:max_trades]
            if recent_trades:
                context.raw_trades = recent_trades
                context.trade_history = self._format_trade_history(recent_trades)
                context.lessons_learned = self._extract_lessons(recent_trades)
        # Similar situations; specialized analysts filter by their tag.
        similar = self._situation_memory.retrieve(
            query=current_situation,
            top_k=3,
            tags=[analyst_type] if analyst_type != "general" else None,
        )
        if similar:
            context.similar_situations = self._format_similar_situations(similar)
        return context

    def get_trader_context(
        self,
        symbol: str,
        current_situation: str,
        proposed_action: str,
        market_regime: Optional[MarketRegime] = None,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for the trader agent.

        Args:
            symbol: Stock symbol being traded.
            current_situation: Current market situation.
            proposed_action: Proposed trade action (buy/sell/hold).
            market_regime: Current market regime.
            user_id: User ID for risk profile.

        Returns:
            MemoryContext with relevant information.
        """
        context = MemoryContext()
        # Past trades for this symbol: show the first five, but extract
        # lessons from the full history.
        trades = self._trade_memory.get_trades_by_symbol(symbol)
        if trades:
            context.raw_trades = trades[:5]
            context.trade_history = self._format_trade_history(trades[:5])
            context.lessons_learned = self._extract_lessons(trades)
        # Risk profile context is only available when a regime is supplied.
        if market_regime:
            risk_level, explanation = self._risk_memory.recommend_risk_level(
                category=RiskCategory.POSITION_SIZE,
                market_regime=market_regime,
                context=current_situation,
                user_id=user_id,
            )
            profile = self._risk_memory.get_or_create_profile(user_id)
            context.risk_context = (
                f"Recommended risk level: {risk_level:.2f}\n"
                f"Base tolerance: {profile.base_tolerance.value}\n"
                f"Reasoning: {explanation}"
            )
        # Similar trading situations, keyed on situation + proposed action.
        similar = self._situation_memory.retrieve(
            query=f"{current_situation} {proposed_action}",
            top_k=3,
        )
        if similar:
            context.similar_situations = self._format_similar_situations(similar)
        return context

    def get_risk_manager_context(
        self,
        symbol: str,
        proposed_trade: str,
        position_size: float,
        market_regime: Optional[MarketRegime] = None,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for risk management agent.

        Args:
            symbol: Stock symbol.
            proposed_trade: Proposed trade description.
            position_size: Proposed position size.
            market_regime: Current market regime.
            user_id: User ID.

        Returns:
            MemoryContext with risk-focused information.
        """
        context = MemoryContext()
        # Past trades with outcome statistics.
        trades = self._trade_memory.get_trades_by_symbol(symbol)
        if trades:
            winning = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
            losing = [t for t in trades if t.outcome == TradeOutcome.LOSS]
            win_rate = len(winning) / len(trades)
            # Average over trades that actually have a recorded return.
            # Fixed: test `is not None` so legitimate 0.0 returns are
            # counted instead of being silently dropped by truthiness.
            returns = [t.returns for t in trades if t.returns is not None]
            avg_return = sum(returns) / max(1, len(returns))
            context.trade_history = (
                f"Trading history for {symbol}:\n"
                f"- Total trades: {len(trades)}\n"
                f"- Win rate: {win_rate:.1%}\n"
                f"- Average return: {avg_return:.2%}\n"
                f"- Winners: {len(winning)}, Losers: {len(losing)}"
            )
            context.lessons_learned = self._extract_risk_lessons(trades)
        # Risk profile and recommendations (requires a market regime).
        if market_regime:
            profile = self._risk_memory.get_or_create_profile(user_id)
            adjusted_tolerance = profile.get_adjusted_tolerance(market_regime)
            risk_level, explanation = self._risk_memory.recommend_risk_level(
                category=RiskCategory.POSITION_SIZE,
                market_regime=market_regime,
                context=proposed_trade,
                user_id=user_id,
            )
            context.risk_context = (
                f"User risk profile:\n"
                f"- Base tolerance: {profile.base_tolerance.value}\n"
                f"- Adjusted for {market_regime.value}: {adjusted_tolerance.value}\n"
                f"- Max drawdown tolerance: {profile.max_drawdown_tolerance:.1%}\n"
                f"- Recommended risk level: {risk_level:.2f}\n"
                f"- Reasoning: {explanation}"
            )
        return context

    def record_trade_outcome(
        self,
        trade: TradeRecord,
        situation_context: str,
        lesson_learned: Optional[str] = None,
    ) -> None:
        """Record a trade outcome for future reference.

        Args:
            trade: The completed trade record.
            situation_context: Description of the market situation.
            lesson_learned: Optional lesson to remember.
        """
        self._trade_memory.record_trade(trade)
        # Record the situation for future similarity matching; larger
        # absolute returns get proportionally higher importance (capped at 1).
        importance = 0.5
        if trade.returns is not None:
            importance = min(1.0, 0.5 + abs(trade.returns))
        entry_content = (
            f"Trade: {trade.direction.value} {trade.symbol} at {trade.entry_price}. "
            f"Outcome: {trade.outcome.value if trade.outcome else 'pending'}. "
            f"Context: {situation_context}"
        )
        if lesson_learned:
            entry_content += f"\nLesson: {lesson_learned}"
        entry = MemoryEntry.create(
            content=entry_content,
            metadata={
                "trade_id": trade.id,
                "symbol": trade.symbol,
                "direction": trade.direction.value,
                "outcome": trade.outcome.value if trade.outcome else None,
                "return": trade.returns,
            },
            importance=importance,
            tags=[trade.symbol, trade.direction.value],
        )
        self._situation_memory.add(entry)

    def record_risk_decision(
        self,
        category: RiskCategory,
        risk_level: float,
        market_regime: MarketRegime,
        context: str,
        user_id: Optional[str] = None,
    ) -> str:
        """Record a risk decision for learning.

        Args:
            category: Risk category.
            risk_level: Risk level chosen.
            market_regime: Current market regime.
            context: Decision context.
            user_id: User ID.

        Returns:
            Decision ID.
        """
        decision = RiskDecision.create(
            category=category,
            risk_level=risk_level,
            market_regime=market_regime,
            context=context,
        )
        return self._risk_memory.record_decision(decision, user_id)

    def evaluate_risk_decision(
        self,
        decision_id: str,
        outcome: str,
        outcome_score: float,
        was_appropriate: bool,
    ) -> None:
        """Evaluate a past risk decision.

        Args:
            decision_id: Decision ID to evaluate.
            outcome: What happened.
            outcome_score: Outcome score (-1 to 1).
            was_appropriate: Whether decision was appropriate.
        """
        self._risk_memory.evaluate_decision(
            decision_id=decision_id,
            outcome=outcome,
            outcome_score=outcome_score,
            was_appropriate=was_appropriate,
        )

    def _format_trade_history(self, trades: List[TradeRecord]) -> str:
        """Format trade history as one bullet line per trade."""
        if not trades:
            return "No recent trades."
        lines = []
        for trade in trades:
            outcome = trade.outcome.value if trade.outcome else "pending"
            # Fixed: a legitimate 0.00% return used to render as "N/A"
            # because the check tested truthiness instead of None.
            ret = f"{trade.returns:+.2%}" if trade.returns is not None else "N/A"
            lines.append(
                f"- {trade.entry_time.strftime('%Y-%m-%d')}: "
                f"{trade.direction.value.upper()} {trade.symbol} @ ${trade.entry_price:.2f} "
                f"-> {outcome} ({ret})"
            )
        return "\n".join(lines)

    def _format_similar_situations(self, scored_entries) -> str:
        """Format up to three scored memory entries for prompts."""
        if not scored_entries:
            return "No similar past situations found."
        lines = []
        for scored in scored_entries[:3]:
            entry = scored.entry
            score = scored.combined_score
            # Truncate long situations so the prompt stays compact.
            lines.append(f"- (relevance: {score:.2f}) {entry.content[:200]}...")
        return "\n".join(lines)

    def _extract_lessons(self, trades: List[TradeRecord]) -> str:
        """Extract up to three lessons learned from past trades."""
        if not trades:
            return "No lessons to extract."
        lessons = []
        # Compare average holding periods of winners vs losers.
        winners = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
        losers = [t for t in trades if t.outcome == TradeOutcome.LOSS]
        if winners and losers:
            win_avg_hold = sum(
                (t.exit_time - t.entry_time).days
                for t in winners if t.exit_time
            ) / max(1, len([t for t in winners if t.exit_time]))
            loss_avg_hold = sum(
                (t.exit_time - t.entry_time).days
                for t in losers if t.exit_time
            ) / max(1, len([t for t in losers if t.exit_time]))
            if win_avg_hold < loss_avg_hold:
                lessons.append("Winners tend to show profits quickly; consider cutting losers earlier.")
            elif loss_avg_hold < win_avg_hold:
                lessons.append("Holding winners longer has been profitable; avoid taking profits too early.")
        # Look for one pattern in agent reasoning; stop at the first match.
        for trade in trades:
            if trade.reasoning and trade.outcome:
                if trade.outcome == TradeOutcome.PROFITABLE:
                    if trade.reasoning.research_conclusion:
                        lessons.append("Trades following analyst conclusions have been profitable.")
                        break
                elif trade.outcome == TradeOutcome.LOSS:
                    if trade.reasoning.risk_assessment:
                        lessons.append("Consider risk assessment more carefully on future trades.")
                        break
        if not lessons:
            return "Continue following current strategy."
        return "\n".join(f"- {lesson}" for lesson in lessons[:3])

    def _extract_risk_lessons(self, trades: List[TradeRecord]) -> str:
        """Extract risk-specific lessons (large losses, oversized positions)."""
        if not trades:
            return "No risk lessons available."
        lessons = []
        # Trades that lost more than 10%.
        large_losses = [
            t for t in trades
            if t.returns is not None and t.returns < -0.1
        ]
        if large_losses:
            lessons.append(
                f"Had {len(large_losses)} trades with >10% losses. "
                "Consider tighter stop-losses."
            )
        # Position sizing pattern: large losing positions.
        trades_with_size = [t for t in trades if t.quantity]
        if trades_with_size:
            large_positions = [
                t for t in trades_with_size
                if t.quantity > 100 and t.outcome == TradeOutcome.LOSS
            ]
            if large_positions:
                lessons.append(
                    "Larger positions have shown higher loss frequency. "
                    "Consider scaling in gradually."
                )
        if not lessons:
            return "No specific risk warnings from recent history."
        return "\n".join(f"- {lesson}" for lesson in lessons)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize all three memory subsystems to a dictionary."""
        return {
            "trade_memory": self._trade_memory.to_dict(),
            "risk_memory": self._risk_memory.to_dict(),
            "situation_memory": self._situation_memory.to_dict(),
        }

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
        embedding_function: Optional[Callable] = None,
    ) -> "AgentMemoryIntegration":
        """Create an instance from a dictionary produced by :meth:`to_dict`.

        Missing keys leave the corresponding freshly-constructed default
        memory in place.
        """
        instance = cls(embedding_function=embedding_function)
        if "trade_memory" in data:
            instance._trade_memory = TradeHistoryMemory.from_dict(
                data["trade_memory"],
                embedding_function=embedding_function,
            )
        if "risk_memory" in data:
            instance._risk_memory = RiskProfileMemory.from_dict(
                data["risk_memory"],
                embedding_function=embedding_function,
            )
        if "situation_memory" in data:
            instance._situation_memory = LayeredMemory.from_dict(
                data["situation_memory"],
                embedding_function=embedding_function,
            )
        return instance
def create_memory_enhanced_prompt(
    base_prompt: str,
    context: MemoryContext,
    context_types: Optional[List[ContextType]] = None,
) -> str:
    """Create a memory-enhanced prompt from a base prompt.

    Args:
        base_prompt: Original agent prompt.
        context: Memory context to include.
        context_types: Types of context to include (None means all).

    Returns:
        The base prompt unchanged when the context is empty; otherwise the
        base prompt followed by a delimited memory-context section.
    """
    if context.is_empty():
        # Nothing to add — hand the prompt back untouched.
        return base_prompt
    memory_section = context.to_prompt_string(context_types)
    pieces = [
        base_prompt,
        "\n\n---\n",
        "# Memory Context (Use this to inform your analysis)\n\n",
        memory_section,
        "\n---",
    ]
    return "".join(pieces)