601 lines
20 KiB
Python
601 lines
20 KiB
Python
"""Memory integration for agent prompts.
|
|
|
|
This module provides integration between the memory system and agent prompts,
|
|
enabling agents to access relevant historical context for better decision-making.
|
|
|
|
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Callable
|
|
from enum import Enum
|
|
|
|
from .layered_memory import LayeredMemory, MemoryEntry, MemoryConfig, ScoringWeights
|
|
from .trade_history import TradeHistoryMemory, TradeRecord, TradeOutcome, TradeDirection
|
|
from .risk_profiles import (
|
|
RiskProfileMemory,
|
|
RiskProfile,
|
|
RiskDecision,
|
|
RiskTolerance,
|
|
MarketRegime,
|
|
RiskCategory,
|
|
)
|
|
|
|
|
|
class ContextType(Enum):
    """Selectors for the sections of memory context a prompt may include.

    Each member's value is the lowercase form of its name. ``ALL`` acts as
    a wildcard: ``MemoryContext.to_prompt_string`` treats it as a request
    for every section.
    """

    # Summary of past trades.
    TRADE_HISTORY = "trade_history"
    # Risk profile recommendations.
    RISK_PROFILE = "risk_profile"
    # Comparable past situations and their outcomes.
    SIMILAR_SITUATIONS = "similar_situations"
    # Lessons distilled from past trades.
    LESSONS_LEARNED = "lessons_learned"
    # Wildcard covering every section above.
    ALL = "all"
|
|
|
|
|
|
@dataclass
class MemoryContext:
    """Bundle of memory-derived text sections destined for an agent prompt.

    Attributes:
        trade_history: Summary of relevant past trades
        risk_context: Risk profile recommendations
        similar_situations: Similar past situations and outcomes
        lessons_learned: Key lessons from past trades
        raw_trades: List of relevant TradeRecord objects
    """

    trade_history: str = ""
    risk_context: str = ""
    similar_situations: str = ""
    lessons_learned: str = ""
    raw_trades: List[TradeRecord] = field(default_factory=list)

    def to_prompt_string(self, include_types: Optional[List[ContextType]] = None) -> str:
        """Render the selected, non-empty sections as Markdown-style text.

        Args:
            include_types: Sections to render. ``None`` is treated as
                ``[ContextType.ALL]``; a section is rendered only when it
                is selected (directly or via ``ALL``) and its text is
                non-empty.

        Returns:
            The rendered sections separated by blank lines, or a fixed
            "no context" sentence when nothing qualifies.
        """
        selected = [ContextType.ALL] if include_types is None else include_types
        wants_everything = ContextType.ALL in selected

        # (selector, heading, text) triples, listed in rendering order.
        candidates = (
            (ContextType.TRADE_HISTORY, "## Recent Trade History", self.trade_history),
            (ContextType.RISK_PROFILE, "## Risk Profile Context", self.risk_context),
            (ContextType.SIMILAR_SITUATIONS, "## Similar Past Situations", self.similar_situations),
            (ContextType.LESSONS_LEARNED, "## Lessons Learned", self.lessons_learned),
        )

        rendered = [
            f"{heading}\n{text}"
            for selector, heading, text in candidates
            if text and (wants_everything or selector in selected)
        ]

        if rendered:
            return "\n\n".join(rendered)
        return "No relevant memory context available."

    def is_empty(self) -> bool:
        """Return True when all four text sections are blank.

        ``raw_trades`` is not consulted.
        """
        text_sections = (
            self.trade_history,
            self.risk_context,
            self.similar_situations,
            self.lessons_learned,
        )
        return all(not section for section in text_sections)
|
|
|
|
|
|
class AgentMemoryIntegration:
    """Integration layer between memory systems and agent prompts.

    This class provides methods to retrieve relevant memory context
    for different agents in the trading system, and to record trade
    outcomes and risk decisions back into the underlying memories.

    Example:
        >>> integration = AgentMemoryIntegration()
        >>>
        >>> # Get context for an analyst
        >>> context = integration.get_analyst_context(
        ...     symbol="AAPL",
        ...     current_situation="Tech sector showing momentum",
        ...     analyst_type="momentum",
        ... )
        >>>
        >>> # Use in prompt
        >>> prompt = f"Analyze AAPL. Memory context: {context.to_prompt_string()}"
    """

    # Rendering limits used by _format_similar_situations.
    _MAX_SIMILAR_SHOWN = 3
    _SNIPPET_CHARS = 200

    def __init__(
        self,
        trade_memory: Optional[TradeHistoryMemory] = None,
        risk_memory: Optional[RiskProfileMemory] = None,
        situation_memory: Optional[LayeredMemory] = None,
        embedding_function: Optional[Callable] = None,
    ):
        """Initialize memory integration.

        Args:
            trade_memory: Trade history memory instance
            risk_memory: Risk profile memory instance
            situation_memory: General situation memory
            embedding_function: Optional embedding function for similarity;
                only forwarded to the default ``LayeredMemory`` built here
                when ``situation_memory`` is not supplied.
        """
        # Test against None rather than truthiness: a caller-supplied store
        # that happens to evaluate falsy (e.g. if it exposes __len__ and is
        # still empty) must not be silently swapped for a fresh instance.
        self._trade_memory = (
            trade_memory if trade_memory is not None else TradeHistoryMemory()
        )
        self._risk_memory = (
            risk_memory if risk_memory is not None else RiskProfileMemory()
        )
        if situation_memory is None:
            situation_memory = LayeredMemory(
                config=MemoryConfig(
                    weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2)
                ),
                embedding_function=embedding_function,
            )
        self._situation_memory = situation_memory
        self._embedding_function = embedding_function

    @property
    def trade_memory(self) -> TradeHistoryMemory:
        """Access trade history memory."""
        return self._trade_memory

    @property
    def risk_memory(self) -> RiskProfileMemory:
        """Access risk profile memory."""
        return self._risk_memory

    @property
    def situation_memory(self) -> LayeredMemory:
        """Access situation memory."""
        return self._situation_memory

    def get_analyst_context(
        self,
        symbol: str,
        current_situation: str,
        analyst_type: str = "general",
        lookback_days: int = 90,
        max_trades: int = 5,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for an analyst agent.

        Args:
            symbol: Stock symbol being analyzed
            current_situation: Current market situation description
            analyst_type: Type of analyst (momentum, macro, etc.); any value
                other than "general" is passed as a tag filter to the
                situation-memory lookup
            lookback_days: Days to look back for trades
            max_trades: Maximum trades to include
            user_id: Currently unused by this method

        Returns:
            MemoryContext with trade history, lessons and similar situations
            filled in where data exists (risk context is not populated here)
        """
        context = MemoryContext()

        # Get relevant past trades for this symbol
        trades = self._trade_memory.get_trades_by_symbol(symbol)

        if trades:
            # Filter to lookback period.
            # NOTE(review): datetime.now() is naive local time; this assumes
            # TradeRecord.entry_time is naive as well - confirm.
            cutoff = datetime.now() - timedelta(days=lookback_days)
            recent_trades = [t for t in trades if t.entry_time >= cutoff][:max_trades]

            if recent_trades:
                context.raw_trades = recent_trades
                context.trade_history = self._format_trade_history(recent_trades)
                context.lessons_learned = self._extract_lessons(recent_trades)

        # Get similar situations
        similar = self._situation_memory.retrieve(
            query=current_situation,
            top_k=3,
            tags=[analyst_type] if analyst_type != "general" else None,
        )

        if similar:
            context.similar_situations = self._format_similar_situations(similar)

        return context

    def get_trader_context(
        self,
        symbol: str,
        current_situation: str,
        proposed_action: str,
        market_regime: Optional[MarketRegime] = None,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for the trader agent.

        Args:
            symbol: Stock symbol being traded
            current_situation: Current market situation
            proposed_action: Proposed trade action (buy/sell/hold)
            market_regime: Current market regime; risk context is only
                produced when this is provided
            user_id: User ID for risk profile

        Returns:
            MemoryContext with relevant information
        """
        context = MemoryContext()

        # Get past trades for this symbol. The history shown is capped at the
        # first five records, while lessons are drawn from every record.
        trades = self._trade_memory.get_trades_by_symbol(symbol)
        if trades:
            shown = trades[:5]
            context.raw_trades = shown
            context.trade_history = self._format_trade_history(shown)
            context.lessons_learned = self._extract_lessons(trades)

        # Get risk profile context
        if market_regime is not None:
            risk_level, explanation = self._risk_memory.recommend_risk_level(
                category=RiskCategory.POSITION_SIZE,
                market_regime=market_regime,
                context=current_situation,
                user_id=user_id,
            )
            profile = self._risk_memory.get_or_create_profile(user_id)
            context.risk_context = (
                f"Recommended risk level: {risk_level:.2f}\n"
                f"Base tolerance: {profile.base_tolerance.value}\n"
                f"Reasoning: {explanation}"
            )

        # Get similar trading situations
        similar = self._situation_memory.retrieve(
            query=f"{current_situation} {proposed_action}",
            top_k=3,
        )

        if similar:
            context.similar_situations = self._format_similar_situations(similar)

        return context

    def get_risk_manager_context(
        self,
        symbol: str,
        proposed_trade: str,
        position_size: float,
        market_regime: Optional[MarketRegime] = None,
        user_id: Optional[str] = None,
    ) -> MemoryContext:
        """Get memory context for risk management agent.

        Args:
            symbol: Stock symbol
            proposed_trade: Proposed trade description
            position_size: Proposed position size (currently unused by this
                method)
            market_regime: Current market regime; risk context is only
                produced when this is provided
            user_id: User ID

        Returns:
            MemoryContext with risk-focused information
        """
        context = MemoryContext()

        # Get past trades with outcome statistics
        trades = self._trade_memory.get_trades_by_symbol(symbol)
        if trades:
            winning = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
            losing = [t for t in trades if t.outcome == TradeOutcome.LOSS]

            # The denominator is every recorded trade, whatever its outcome.
            win_rate = len(winning) / len(trades)

            # Average over every trade with a known return. A break-even
            # return of 0.0 is a real data point, so test for None instead
            # of relying on truthiness.
            known_returns = [t.returns for t in trades if t.returns is not None]
            avg_return = (
                sum(known_returns) / len(known_returns) if known_returns else 0.0
            )

            context.trade_history = (
                f"Trading history for {symbol}:\n"
                f"- Total trades: {len(trades)}\n"
                f"- Win rate: {win_rate:.1%}\n"
                f"- Average return: {avg_return:.2%}\n"
                f"- Winners: {len(winning)}, Losers: {len(losing)}"
            )

            # Extract risk lessons
            context.lessons_learned = self._extract_risk_lessons(trades)

        # Get risk profile and recommendations
        if market_regime is not None:
            profile = self._risk_memory.get_or_create_profile(user_id)
            adjusted_tolerance = profile.get_adjusted_tolerance(market_regime)

            risk_level, explanation = self._risk_memory.recommend_risk_level(
                category=RiskCategory.POSITION_SIZE,
                market_regime=market_regime,
                context=proposed_trade,
                user_id=user_id,
            )

            context.risk_context = (
                f"User risk profile:\n"
                f"- Base tolerance: {profile.base_tolerance.value}\n"
                f"- Adjusted for {market_regime.value}: {adjusted_tolerance.value}\n"
                f"- Max drawdown tolerance: {profile.max_drawdown_tolerance:.1%}\n"
                f"- Recommended risk level: {risk_level:.2f}\n"
                f"- Reasoning: {explanation}"
            )

        return context

    def record_trade_outcome(
        self,
        trade: TradeRecord,
        situation_context: str,
        lesson_learned: Optional[str] = None,
    ) -> None:
        """Record a trade outcome for future reference.

        The trade is stored in trade memory, and a textual summary of it is
        added to situation memory so later similarity lookups can find it.

        Args:
            trade: The completed trade record
            situation_context: Description of the market situation
            lesson_learned: Optional lesson to remember
        """
        # Record in trade memory
        self._trade_memory.record_trade(trade)

        # Importance starts at 0.5 and grows with the magnitude of the
        # return, capped at 1.0, so significant outcomes rank higher.
        importance = 0.5
        if trade.returns is not None:
            importance = min(1.0, 0.5 + abs(trade.returns))

        entry_content = (
            f"Trade: {trade.direction.value} {trade.symbol} at {trade.entry_price}. "
            f"Outcome: {trade.outcome.value if trade.outcome else 'pending'}. "
            f"Context: {situation_context}"
        )

        if lesson_learned:
            entry_content += f"\nLesson: {lesson_learned}"

        entry = MemoryEntry.create(
            content=entry_content,
            metadata={
                "trade_id": trade.id,
                "symbol": trade.symbol,
                "direction": trade.direction.value,
                "outcome": trade.outcome.value if trade.outcome else None,
                "return": trade.returns,
            },
            importance=importance,
            tags=[trade.symbol, trade.direction.value],
        )

        self._situation_memory.add(entry)

    def record_risk_decision(
        self,
        category: RiskCategory,
        risk_level: float,
        market_regime: MarketRegime,
        context: str,
        user_id: Optional[str] = None,
    ) -> str:
        """Record a risk decision for learning.

        Args:
            category: Risk category
            risk_level: Risk level chosen
            market_regime: Current market regime
            context: Decision context
            user_id: User ID

        Returns:
            Decision ID
        """
        decision = RiskDecision.create(
            category=category,
            risk_level=risk_level,
            market_regime=market_regime,
            context=context,
        )

        return self._risk_memory.record_decision(decision, user_id)

    def evaluate_risk_decision(
        self,
        decision_id: str,
        outcome: str,
        outcome_score: float,
        was_appropriate: bool,
    ) -> None:
        """Evaluate a past risk decision.

        Thin pass-through to the risk memory's ``evaluate_decision``.

        Args:
            decision_id: Decision ID to evaluate
            outcome: What happened
            outcome_score: Outcome score (-1 to 1)
            was_appropriate: Whether decision was appropriate
        """
        self._risk_memory.evaluate_decision(
            decision_id=decision_id,
            outcome=outcome,
            outcome_score=outcome_score,
            was_appropriate=was_appropriate,
        )

    def _format_trade_history(self, trades: List[TradeRecord]) -> str:
        """Format trade history for prompts, one bullet line per trade."""
        if not trades:
            return "No recent trades."

        lines = []
        for trade in trades:
            outcome = trade.outcome.value if trade.outcome else "pending"
            # A 0.0 return is a valid break-even result; only a missing
            # return is shown as "N/A".
            ret = f"{trade.returns:+.2%}" if trade.returns is not None else "N/A"
            lines.append(
                f"- {trade.entry_time.strftime('%Y-%m-%d')}: "
                f"{trade.direction.value.upper()} {trade.symbol} @ ${trade.entry_price:.2f} "
                f"-> {outcome} ({ret})"
            )

        return "\n".join(lines)

    def _format_similar_situations(self, scored_entries) -> str:
        """Format similar situations for prompts.

        At most the first three entries are rendered. Each line carries the
        entry's ``combined_score`` and up to 200 characters of its content;
        an ellipsis is appended only when the content was actually cut.
        """
        if not scored_entries:
            return "No similar past situations found."

        lines = []
        for scored in scored_entries[: self._MAX_SIMILAR_SHOWN]:
            content = scored.entry.content
            snippet = content[: self._SNIPPET_CHARS]
            ellipsis = "..." if len(content) > self._SNIPPET_CHARS else ""
            lines.append(
                f"- (relevance: {scored.combined_score:.2f}) {snippet}{ellipsis}"
            )

        return "\n".join(lines)

    def _extract_lessons(self, trades: List[TradeRecord]) -> str:
        """Extract lessons learned from trades.

        Two heuristics are applied: a comparison of average holding time
        between profitable and losing trades, and a scan for the first
        trade whose recorded reasoning matches its outcome. At most three
        lessons are returned as bullet lines.
        """
        if not trades:
            return "No lessons to extract."

        lessons = []

        # Analyze winning vs losing trades
        winners = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
        losers = [t for t in trades if t.outcome == TradeOutcome.LOSS]

        win_holds = [(t.exit_time - t.entry_time).days for t in winners if t.exit_time]
        loss_holds = [(t.exit_time - t.entry_time).days for t in losers if t.exit_time]

        # Compare only when both sides have at least one trade with an exit
        # time; otherwise a missing side would count as a 0-day hold and
        # yield a lesson with no data behind it.
        if win_holds and loss_holds:
            win_avg_hold = sum(win_holds) / len(win_holds)
            loss_avg_hold = sum(loss_holds) / len(loss_holds)

            if win_avg_hold < loss_avg_hold:
                lessons.append("Winners tend to show profits quickly; consider cutting losers earlier.")
            elif loss_avg_hold < win_avg_hold:
                lessons.append("Holding winners longer has been profitable; avoid taking profits too early.")

        # Look for patterns in agent reasoning; stop at the first match so
        # at most one reasoning-based lesson is added.
        for trade in trades:
            if not (trade.reasoning and trade.outcome):
                continue
            if trade.outcome == TradeOutcome.PROFITABLE:
                if trade.reasoning.research_conclusion:
                    lessons.append("Trades following analyst conclusions have been profitable.")
                    break
            elif trade.outcome == TradeOutcome.LOSS:
                if trade.reasoning.risk_assessment:
                    lessons.append("Consider risk assessment more carefully on future trades.")
                    break

        if not lessons:
            return "Continue following current strategy."

        return "\n".join(f"- {lesson}" for lesson in lessons[:3])

    def _extract_risk_lessons(self, trades: List[TradeRecord]) -> str:
        """Extract risk-specific lessons from trades.

        Flags trades whose return is below -10% and losing trades whose
        quantity exceeds 100.
        """
        if not trades:
            return "No risk lessons available."

        lessons = []

        # Analyze large losses
        large_losses = [
            t for t in trades
            if t.returns is not None and t.returns < -0.1
        ]

        if large_losses:
            lessons.append(
                f"Had {len(large_losses)} trades with >10% losses. "
                "Consider tighter stop-losses."
            )

        # Check for position sizing patterns: any losing trade with a
        # quantity above 100 triggers the warning.
        large_losing_positions = [
            t for t in trades
            if t.quantity and t.quantity > 100 and t.outcome == TradeOutcome.LOSS
        ]
        if large_losing_positions:
            lessons.append(
                "Larger positions have shown higher loss frequency. "
                "Consider scaling in gradually."
            )

        if not lessons:
            return "No specific risk warnings from recent history."

        return "\n".join(f"- {lesson}" for lesson in lessons)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the three underlying memories to a dictionary."""
        return {
            "trade_memory": self._trade_memory.to_dict(),
            "risk_memory": self._risk_memory.to_dict(),
            "situation_memory": self._situation_memory.to_dict(),
        }

    @classmethod
    def from_dict(
        cls,
        data: Dict[str, Any],
        embedding_function: Optional[Callable] = None,
    ) -> "AgentMemoryIntegration":
        """Create from dictionary.

        Args:
            data: Mapping as produced by ``to_dict``; any of the three keys
                may be absent, in which case that memory keeps the default
                built by ``__init__``
            embedding_function: Forwarded to the constructor and to each
                memory's ``from_dict``

        Returns:
            A new AgentMemoryIntegration holding the restored memories
        """
        instance = cls(embedding_function=embedding_function)

        if "trade_memory" in data:
            instance._trade_memory = TradeHistoryMemory.from_dict(
                data["trade_memory"],
                embedding_function=embedding_function,
            )

        if "risk_memory" in data:
            instance._risk_memory = RiskProfileMemory.from_dict(
                data["risk_memory"],
                embedding_function=embedding_function,
            )

        if "situation_memory" in data:
            instance._situation_memory = LayeredMemory.from_dict(
                data["situation_memory"],
                embedding_function=embedding_function,
            )

        return instance
|
|
|
|
|
|
def create_memory_enhanced_prompt(
    base_prompt: str,
    context: MemoryContext,
    context_types: Optional[List[ContextType]] = None,
) -> str:
    """Append a delimited memory section to an agent prompt.

    Args:
        base_prompt: Original agent prompt
        context: Memory context to include
        context_types: Types of context to include; forwarded unchanged to
            ``context.to_prompt_string``

    Returns:
        ``base_prompt`` untouched when ``context.is_empty()`` is true;
        otherwise the prompt followed by a ``---``-fenced block holding
        the rendered memory context
    """
    # Guard clause: nothing stored, nothing to append.
    if context.is_empty():
        return base_prompt

    rendered_memory = context.to_prompt_string(context_types)
    section_header = "---\n# Memory Context (Use this to inform your analysis)\n\n"
    section_footer = "\n---"

    return f"{base_prompt}\n\n{section_header}{rendered_memory}{section_footer}"
|