feat(memory): add memory integration for agent prompts - Fixes #21

This commit is contained in:
Andrew Kaszubski 2025-12-26 20:40:21 +11:00
parent 25c31d5f5d
commit 4f6f7c1c14
4 changed files with 1166 additions and 0 deletions

View File

@ -0,0 +1,548 @@
"""Tests for Memory Integration module.
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from tradingagents.memory.integration import (
AgentMemoryIntegration,
MemoryContext,
ContextType,
create_memory_enhanced_prompt,
)
from tradingagents.memory.trade_history import (
TradeRecord,
TradeOutcome,
TradeDirection,
SignalStrength,
AgentReasoning,
)
from tradingagents.memory.risk_profiles import (
RiskCategory,
MarketRegime,
RiskTolerance,
RiskProfile,
)
# =============================================================================
# MemoryContext Tests
# =============================================================================
class TestMemoryContext:
"""Tests for MemoryContext dataclass."""
def test_empty_context(self):
"""Test empty context detection."""
context = MemoryContext()
assert context.is_empty()
def test_non_empty_context(self):
"""Test non-empty context detection."""
context = MemoryContext(trade_history="Some trade history")
assert not context.is_empty()
def test_to_prompt_string_empty(self):
"""Test prompt string for empty context."""
context = MemoryContext()
result = context.to_prompt_string()
assert "No relevant memory context" in result
def test_to_prompt_string_with_history(self):
"""Test prompt string with trade history."""
context = MemoryContext(trade_history="AAPL: +5% last week")
result = context.to_prompt_string()
assert "Recent Trade History" in result
assert "AAPL: +5% last week" in result
def test_to_prompt_string_with_all_sections(self):
"""Test prompt string with all sections."""
context = MemoryContext(
trade_history="Trade history content",
risk_context="Risk context content",
similar_situations="Similar situations content",
lessons_learned="Lessons learned content",
)
result = context.to_prompt_string()
assert "Recent Trade History" in result
assert "Risk Profile Context" in result
assert "Similar Past Situations" in result
assert "Lessons Learned" in result
def test_to_prompt_string_filter_by_type(self):
"""Test filtering context by type."""
context = MemoryContext(
trade_history="Trade history content",
risk_context="Risk context content",
lessons_learned="Lessons learned content",
)
# Only trade history
result = context.to_prompt_string([ContextType.TRADE_HISTORY])
assert "Recent Trade History" in result
assert "Risk Profile Context" not in result
assert "Lessons Learned" not in result
def test_to_prompt_string_multiple_types(self):
"""Test filtering with multiple types."""
context = MemoryContext(
trade_history="Trade history content",
risk_context="Risk context content",
lessons_learned="Lessons learned content",
)
result = context.to_prompt_string([
ContextType.TRADE_HISTORY,
ContextType.RISK_PROFILE,
])
assert "Recent Trade History" in result
assert "Risk Profile Context" in result
assert "Lessons Learned" not in result
# =============================================================================
# AgentMemoryIntegration Tests
# =============================================================================
class TestAgentMemoryIntegration:
"""Tests for AgentMemoryIntegration class."""
def test_create_integration(self):
"""Test creating integration instance."""
integration = AgentMemoryIntegration()
assert integration.trade_memory is not None
assert integration.risk_memory is not None
assert integration.situation_memory is not None
def test_get_analyst_context_empty(self):
"""Test getting analyst context with no history."""
integration = AgentMemoryIntegration()
context = integration.get_analyst_context(
symbol="AAPL",
current_situation="Tech sector showing strength",
analyst_type="momentum",
)
# Should return context (may be empty)
assert isinstance(context, MemoryContext)
def test_get_analyst_context_with_trades(self):
"""Test getting analyst context with trade history."""
integration = AgentMemoryIntegration()
# Add some trades
trade = TradeRecord.create(
symbol="AAPL",
direction=TradeDirection.LONG,
entry_price=150.0,
quantity=100,
)
trade.close(exit_price=160.0)
integration.trade_memory.record_trade(trade)
context = integration.get_analyst_context(
symbol="AAPL",
current_situation="Tech rally",
analyst_type="momentum",
)
assert len(context.raw_trades) > 0
assert "AAPL" in context.trade_history
def test_get_trader_context_empty(self):
"""Test getting trader context with no history."""
integration = AgentMemoryIntegration()
context = integration.get_trader_context(
symbol="TSLA",
current_situation="EV sector momentum",
proposed_action="buy",
)
assert isinstance(context, MemoryContext)
def test_get_trader_context_with_regime(self):
"""Test getting trader context with market regime."""
integration = AgentMemoryIntegration()
# Set up a profile
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
integration.risk_memory.set_profile(profile)
context = integration.get_trader_context(
symbol="AAPL",
current_situation="Bull market",
proposed_action="buy",
market_regime=MarketRegime.BULL,
)
assert "Recommended risk level" in context.risk_context
assert "moderate" in context.risk_context.lower()
def test_get_risk_manager_context(self):
"""Test getting risk manager context."""
integration = AgentMemoryIntegration()
# Add some trades with outcomes
for i in range(5):
trade = TradeRecord.create(
symbol="MSFT",
direction=TradeDirection.LONG,
entry_price=300.0 + i,
quantity=100,
)
# Some winners, some losers
if i % 2 == 0:
trade.close(exit_price=310.0 + i)
else:
trade.close(exit_price=290.0 + i)
integration.trade_memory.record_trade(trade)
context = integration.get_risk_manager_context(
symbol="MSFT",
proposed_trade="Buy 100 shares at $305",
position_size=30500,
market_regime=MarketRegime.BULL,
)
assert "Trading history for MSFT" in context.trade_history
assert "Win rate" in context.trade_history
def test_record_trade_outcome(self):
"""Test recording a trade outcome."""
integration = AgentMemoryIntegration()
trade = TradeRecord.create(
symbol="GOOGL",
direction=TradeDirection.LONG,
entry_price=140.0,
quantity=50,
)
trade.close(exit_price=150.0)
integration.record_trade_outcome(
trade=trade,
situation_context="Tech sector showing AI momentum",
lesson_learned="AI momentum trades tend to work well",
)
# Trade should be in memory
trades = integration.trade_memory.get_trades_by_symbol("GOOGL")
assert len(trades) == 1
assert trades[0].symbol == "GOOGL"
# Situation should be recorded
assert integration.situation_memory.count() == 1
def test_record_risk_decision(self):
"""Test recording a risk decision."""
integration = AgentMemoryIntegration()
decision_id = integration.record_risk_decision(
category=RiskCategory.POSITION_SIZE,
risk_level=0.6,
market_regime=MarketRegime.BULL,
context="Strong momentum, increasing position",
)
assert decision_id is not None
# Decision should be recorded
decision = integration.risk_memory.get_decision(decision_id)
assert decision is not None
assert decision.risk_level == 0.6
def test_evaluate_risk_decision(self):
"""Test evaluating a risk decision."""
integration = AgentMemoryIntegration()
decision_id = integration.record_risk_decision(
category=RiskCategory.LEVERAGE,
risk_level=0.7,
market_regime=MarketRegime.LOW_VOLATILITY,
context="Low vol, using leverage",
)
integration.evaluate_risk_decision(
decision_id=decision_id,
outcome="Profitable trade with leverage",
outcome_score=0.6,
was_appropriate=True,
)
decision = integration.risk_memory.get_decision(decision_id)
assert decision.was_appropriate is True
assert decision.outcome_score == 0.6
def test_to_dict_and_from_dict(self):
"""Test serialization roundtrip."""
integration = AgentMemoryIntegration()
# Add some data
trade = TradeRecord.create(
symbol="NVDA",
direction=TradeDirection.LONG,
entry_price=500.0,
quantity=20,
)
trade.close(exit_price=550.0)
integration.trade_memory.record_trade(trade)
profile = RiskProfile(user_id="test", base_tolerance=RiskTolerance.AGGRESSIVE)
integration.risk_memory.set_profile(profile)
# Serialize and restore
data = integration.to_dict()
restored = AgentMemoryIntegration.from_dict(data)
# Verify data preserved
trades = restored.trade_memory.get_trades_by_symbol("NVDA")
assert len(trades) == 1
profile = restored.risk_memory.get_profile("test")
assert profile is not None
assert profile.base_tolerance == RiskTolerance.AGGRESSIVE
# =============================================================================
# Helper Function Tests
# =============================================================================
class TestCreateMemoryEnhancedPrompt:
"""Tests for create_memory_enhanced_prompt function."""
def test_empty_context_returns_base(self):
"""Test that empty context returns base prompt unchanged."""
base_prompt = "Analyze the stock"
context = MemoryContext()
result = create_memory_enhanced_prompt(base_prompt, context)
assert result == base_prompt
def test_adds_memory_section(self):
"""Test that memory section is added."""
base_prompt = "Analyze the stock"
context = MemoryContext(trade_history="Previous trade: +5%")
result = create_memory_enhanced_prompt(base_prompt, context)
assert "Analyze the stock" in result
assert "Memory Context" in result
assert "Previous trade: +5%" in result
def test_respects_context_types(self):
"""Test filtering by context types."""
base_prompt = "Analyze"
context = MemoryContext(
trade_history="Trade history",
risk_context="Risk context",
)
result = create_memory_enhanced_prompt(
base_prompt,
context,
[ContextType.TRADE_HISTORY],
)
assert "Trade history" in result
assert "Risk context" not in result
# =============================================================================
# Integration Tests
# =============================================================================
class TestAgentMemoryIntegrationWorkflow:
"""Integration tests for complete memory workflow."""
def test_full_trading_workflow(self):
"""Test a complete trading workflow with memory."""
integration = AgentMemoryIntegration()
# 1. Set up user profile
profile = RiskProfile(
user_id="trader1",
base_tolerance=RiskTolerance.MODERATE,
)
integration.risk_memory.set_profile(profile)
# 2. Record some historical trades
historical_trades = [
("AAPL", TradeDirection.LONG, 150.0, 160.0),
("AAPL", TradeDirection.LONG, 155.0, 165.0),
("AAPL", TradeDirection.LONG, 160.0, 155.0), # Loss
]
for symbol, direction, entry, exit in historical_trades:
trade = TradeRecord.create(
symbol=symbol,
direction=direction,
entry_price=entry,
quantity=100,
)
trade.close(exit_price=exit)
integration.record_trade_outcome(
trade=trade,
situation_context=f"Trade in {symbol}",
)
# 3. Get analyst context for new analysis
analyst_context = integration.get_analyst_context(
symbol="AAPL",
current_situation="Apple showing momentum",
analyst_type="momentum",
)
assert len(analyst_context.raw_trades) == 3
assert "AAPL" in analyst_context.trade_history
# 4. Get trader context for decision
trader_context = integration.get_trader_context(
symbol="AAPL",
current_situation="Strong momentum signal",
proposed_action="buy",
market_regime=MarketRegime.BULL,
user_id="trader1",
)
assert "Recommended risk level" in trader_context.risk_context
# 5. Get risk manager context
risk_context = integration.get_risk_manager_context(
symbol="AAPL",
proposed_trade="Buy 100 shares",
position_size=16000,
market_regime=MarketRegime.BULL,
user_id="trader1",
)
assert "Win rate" in risk_context.trade_history
assert "moderate" in risk_context.risk_context.lower()
def test_memory_influences_recommendations(self):
"""Test that memory influences risk recommendations."""
integration = AgentMemoryIntegration()
# Set up profile
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
integration.risk_memory.set_profile(profile)
# Record successful high-risk decisions
for i in range(5):
decision_id = integration.record_risk_decision(
category=RiskCategory.POSITION_SIZE,
risk_level=0.7,
market_regime=MarketRegime.BULL,
context=f"Successful trade {i}",
)
integration.evaluate_risk_decision(
decision_id=decision_id,
outcome="Profitable",
outcome_score=0.6,
was_appropriate=True,
)
# Get recommendation - should be influenced by history
risk_level, explanation = integration.risk_memory.recommend_risk_level(
category=RiskCategory.POSITION_SIZE,
market_regime=MarketRegime.BULL,
context="Similar successful situation",
)
# With successful high-risk history, recommendation should be elevated
# Base moderate (0.375) + bull adjustment (0.1) = 0.475
# But with history of successful 0.7 decisions, should be higher
assert risk_level > 0.5
def test_lessons_extracted_from_trades(self):
"""Test that lessons are extracted from trade patterns."""
integration = AgentMemoryIntegration()
# Record winning trades with short hold times
for i in range(3):
trade = TradeRecord.create(
symbol="TSLA",
direction=TradeDirection.LONG,
entry_price=200.0,
quantity=50,
)
# Quick win - close the trade
trade.close(exit_price=220.0)
integration.trade_memory.record_trade(trade)
# Record losing trades with long hold times
for i in range(3):
trade = TradeRecord.create(
symbol="TSLA",
direction=TradeDirection.LONG,
entry_price=200.0,
quantity=50,
)
# Slow loss
trade.close(exit_price=180.0)
integration.trade_memory.record_trade(trade)
context = integration.get_analyst_context(
symbol="TSLA",
current_situation="EV sector",
analyst_type="general",
)
# Lessons should exist (may be strategy continuation or specific lesson)
assert context.lessons_learned != ""
class TestContextTypeFiltering:
"""Tests for context type filtering."""
def test_all_types(self):
"""Test ALL context type includes everything."""
context = MemoryContext(
trade_history="History",
risk_context="Risk",
similar_situations="Similar",
lessons_learned="Lessons",
)
result = context.to_prompt_string([ContextType.ALL])
assert "History" in result
assert "Risk" in result
assert "Similar" in result
assert "Lessons" in result
def test_single_type(self):
"""Test filtering to single type."""
context = MemoryContext(
trade_history="History",
risk_context="Risk",
)
result = context.to_prompt_string([ContextType.RISK_PROFILE])
assert "History" not in result
assert "Risk" in result
def test_empty_sections_not_included(self):
"""Test that empty sections aren't shown."""
context = MemoryContext(
trade_history="History",
risk_context="", # Empty
)
result = context.to_prompt_string([ContextType.ALL])
assert "Recent Trade History" in result
assert "Risk Profile Context" not in result

View File

@ -64,6 +64,11 @@ class AgentState(MessagesState):
macro_report: Annotated[str, "Report from the Macro Analyst"]
correlation_report: Annotated[str, "Report from the Correlation Analyst"]
# memory context (Issue #21)
memory_context: Annotated[str, "Memory context for agents including past trades and lessons"]
trade_history_context: Annotated[str, "Relevant past trade history for this ticker"]
risk_context: Annotated[str, "Risk profile context and recommendations"]
# researcher team discussion step
investment_debate_state: Annotated[
InvestDebateState, "Current state of the debate on if to invest or not"

View File

@ -8,6 +8,7 @@ This module provides a layered memory system with three scoring dimensions:
Issue #18: Layered memory - recency, relevancy, importance scoring
Issue #19: Trade history memory - outcomes, agent reasoning
Issue #20: Risk profiles memory - user preferences over time
Issue #21: Memory integration - retrieval in agent prompts
"""
from .layered_memory import (
@ -38,6 +39,13 @@ from .risk_profiles import (
RiskCategory,
)
from .integration import (
AgentMemoryIntegration,
MemoryContext,
ContextType,
create_memory_enhanced_prompt,
)
__all__ = [
# Layered Memory (Issue #18)
"LayeredMemory",
@ -61,4 +69,9 @@ __all__ = [
"RiskTolerance",
"MarketRegime",
"RiskCategory",
# Memory Integration (Issue #21)
"AgentMemoryIntegration",
"MemoryContext",
"ContextType",
"create_memory_enhanced_prompt",
]

View File

@ -0,0 +1,600 @@
"""Memory integration for agent prompts.
This module provides integration between the memory system and agent prompts,
enabling agents to access relevant historical context for better decision-making.
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable
from enum import Enum
from .layered_memory import LayeredMemory, MemoryEntry, MemoryConfig, ScoringWeights
from .trade_history import TradeHistoryMemory, TradeRecord, TradeOutcome, TradeDirection
from .risk_profiles import (
RiskProfileMemory,
RiskProfile,
RiskDecision,
RiskTolerance,
MarketRegime,
RiskCategory,
)
class ContextType(Enum):
"""Types of memory context that can be retrieved."""
TRADE_HISTORY = "trade_history"
RISK_PROFILE = "risk_profile"
SIMILAR_SITUATIONS = "similar_situations"
LESSONS_LEARNED = "lessons_learned"
ALL = "all"
@dataclass
class MemoryContext:
"""Memory context for agent prompts.
Attributes:
trade_history: Summary of relevant past trades
risk_context: Risk profile recommendations
similar_situations: Similar past situations and outcomes
lessons_learned: Key lessons from past trades
raw_trades: List of relevant TradeRecord objects
"""
trade_history: str = ""
risk_context: str = ""
similar_situations: str = ""
lessons_learned: str = ""
raw_trades: List[TradeRecord] = field(default_factory=list)
def to_prompt_string(self, include_types: Optional[List[ContextType]] = None) -> str:
"""Convert memory context to a string for agent prompts.
Args:
include_types: Types of context to include (default: all non-empty)
Returns:
Formatted string for agent prompts
"""
if include_types is None:
include_types = [ContextType.ALL]
parts = []
if ContextType.ALL in include_types or ContextType.TRADE_HISTORY in include_types:
if self.trade_history:
parts.append(f"## Recent Trade History\n{self.trade_history}")
if ContextType.ALL in include_types or ContextType.RISK_PROFILE in include_types:
if self.risk_context:
parts.append(f"## Risk Profile Context\n{self.risk_context}")
if ContextType.ALL in include_types or ContextType.SIMILAR_SITUATIONS in include_types:
if self.similar_situations:
parts.append(f"## Similar Past Situations\n{self.similar_situations}")
if ContextType.ALL in include_types or ContextType.LESSONS_LEARNED in include_types:
if self.lessons_learned:
parts.append(f"## Lessons Learned\n{self.lessons_learned}")
if not parts:
return "No relevant memory context available."
return "\n\n".join(parts)
def is_empty(self) -> bool:
"""Check if context is empty."""
return not any([
self.trade_history,
self.risk_context,
self.similar_situations,
self.lessons_learned,
])
class AgentMemoryIntegration:
"""Integration layer between memory systems and agent prompts.
This class provides methods to retrieve relevant memory context
for different agents in the trading system.
Example:
>>> integration = AgentMemoryIntegration()
>>>
>>> # Get context for an analyst
>>> context = integration.get_analyst_context(
... ticker="AAPL",
... current_situation="Tech sector showing momentum",
... analyst_type="momentum",
... )
>>>
>>> # Use in prompt
>>> prompt = f"Analyze {ticker}. Memory context: {context.to_prompt_string()}"
"""
def __init__(
self,
trade_memory: Optional[TradeHistoryMemory] = None,
risk_memory: Optional[RiskProfileMemory] = None,
situation_memory: Optional[LayeredMemory] = None,
embedding_function: Optional[Callable] = None,
):
"""Initialize memory integration.
Args:
trade_memory: Trade history memory instance
risk_memory: Risk profile memory instance
situation_memory: General situation memory
embedding_function: Optional embedding function for similarity
"""
self._trade_memory = trade_memory or TradeHistoryMemory()
self._risk_memory = risk_memory or RiskProfileMemory()
self._situation_memory = situation_memory or LayeredMemory(
config=MemoryConfig(
weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2)
),
embedding_function=embedding_function,
)
self._embedding_function = embedding_function
@property
def trade_memory(self) -> TradeHistoryMemory:
"""Access trade history memory."""
return self._trade_memory
@property
def risk_memory(self) -> RiskProfileMemory:
"""Access risk profile memory."""
return self._risk_memory
@property
def situation_memory(self) -> LayeredMemory:
"""Access situation memory."""
return self._situation_memory
def get_analyst_context(
self,
symbol: str,
current_situation: str,
analyst_type: str = "general",
lookback_days: int = 90,
max_trades: int = 5,
user_id: Optional[str] = None,
) -> MemoryContext:
"""Get memory context for an analyst agent.
Args:
symbol: Stock symbol being analyzed
current_situation: Current market situation description
analyst_type: Type of analyst (momentum, macro, etc.)
lookback_days: Days to look back for trades
max_trades: Maximum trades to include
user_id: User ID for risk profile
Returns:
MemoryContext with relevant information
"""
context = MemoryContext()
# Get relevant past trades for this symbol
trades = self._trade_memory.get_trades_by_symbol(symbol)
if trades:
# Filter to lookback period
cutoff = datetime.now() - timedelta(days=lookback_days)
recent_trades = [t for t in trades if t.entry_time >= cutoff][:max_trades]
if recent_trades:
context.raw_trades = recent_trades
context.trade_history = self._format_trade_history(recent_trades)
context.lessons_learned = self._extract_lessons(recent_trades)
# Get similar situations
similar = self._situation_memory.retrieve(
query=current_situation,
top_k=3,
tags=[analyst_type] if analyst_type != "general" else None,
)
if similar:
context.similar_situations = self._format_similar_situations(similar)
return context
def get_trader_context(
self,
symbol: str,
current_situation: str,
proposed_action: str,
market_regime: Optional[MarketRegime] = None,
user_id: Optional[str] = None,
) -> MemoryContext:
"""Get memory context for the trader agent.
Args:
symbol: Stock symbol being traded
current_situation: Current market situation
proposed_action: Proposed trade action (buy/sell/hold)
market_regime: Current market regime
user_id: User ID for risk profile
Returns:
MemoryContext with relevant information
"""
context = MemoryContext()
# Get past trades for this symbol
trades = self._trade_memory.get_trades_by_symbol(symbol)
if trades:
context.raw_trades = trades[:5]
context.trade_history = self._format_trade_history(trades[:5])
context.lessons_learned = self._extract_lessons(trades)
# Get risk profile context
if market_regime:
risk_level, explanation = self._risk_memory.recommend_risk_level(
category=RiskCategory.POSITION_SIZE,
market_regime=market_regime,
context=current_situation,
user_id=user_id,
)
profile = self._risk_memory.get_or_create_profile(user_id)
context.risk_context = (
f"Recommended risk level: {risk_level:.2f}\n"
f"Base tolerance: {profile.base_tolerance.value}\n"
f"Reasoning: {explanation}"
)
# Get similar trading situations
similar = self._situation_memory.retrieve(
query=f"{current_situation} {proposed_action}",
top_k=3,
)
if similar:
context.similar_situations = self._format_similar_situations(similar)
return context
def get_risk_manager_context(
self,
symbol: str,
proposed_trade: str,
position_size: float,
market_regime: Optional[MarketRegime] = None,
user_id: Optional[str] = None,
) -> MemoryContext:
"""Get memory context for risk management agent.
Args:
symbol: Stock symbol
proposed_trade: Proposed trade description
position_size: Proposed position size
market_regime: Current market regime
user_id: User ID
Returns:
MemoryContext with risk-focused information
"""
context = MemoryContext()
# Get past trades with outcome statistics
trades = self._trade_memory.get_trades_by_symbol(symbol)
if trades:
winning = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
losing = [t for t in trades if t.outcome == TradeOutcome.LOSS]
win_rate = len(winning) / len(trades) if trades else 0
avg_return = sum(
t.returns or 0 for t in trades if t.returns
) / max(1, len([t for t in trades if t.returns]))
context.trade_history = (
f"Trading history for {symbol}:\n"
f"- Total trades: {len(trades)}\n"
f"- Win rate: {win_rate:.1%}\n"
f"- Average return: {avg_return:.2%}\n"
f"- Winners: {len(winning)}, Losers: {len(losing)}"
)
# Extract risk lessons
context.lessons_learned = self._extract_risk_lessons(trades)
# Get risk profile and recommendations
if market_regime:
profile = self._risk_memory.get_or_create_profile(user_id)
adjusted_tolerance = profile.get_adjusted_tolerance(market_regime)
risk_level, explanation = self._risk_memory.recommend_risk_level(
category=RiskCategory.POSITION_SIZE,
market_regime=market_regime,
context=proposed_trade,
user_id=user_id,
)
context.risk_context = (
f"User risk profile:\n"
f"- Base tolerance: {profile.base_tolerance.value}\n"
f"- Adjusted for {market_regime.value}: {adjusted_tolerance.value}\n"
f"- Max drawdown tolerance: {profile.max_drawdown_tolerance:.1%}\n"
f"- Recommended risk level: {risk_level:.2f}\n"
f"- Reasoning: {explanation}"
)
return context
def record_trade_outcome(
self,
trade: TradeRecord,
situation_context: str,
lesson_learned: Optional[str] = None,
) -> None:
"""Record a trade outcome for future reference.
Args:
trade: The completed trade record
situation_context: Description of the market situation
lesson_learned: Optional lesson to remember
"""
# Record in trade memory
self._trade_memory.record_trade(trade)
# Record situation for future similarity matching
importance = 0.5
if trade.returns:
# Higher importance for significant outcomes
importance = min(1.0, 0.5 + abs(trade.returns))
entry_content = (
f"Trade: {trade.direction.value} {trade.symbol} at {trade.entry_price}. "
f"Outcome: {trade.outcome.value if trade.outcome else 'pending'}. "
f"Context: {situation_context}"
)
if lesson_learned:
entry_content += f"\nLesson: {lesson_learned}"
entry = MemoryEntry.create(
content=entry_content,
metadata={
"trade_id": trade.id,
"symbol": trade.symbol,
"direction": trade.direction.value,
"outcome": trade.outcome.value if trade.outcome else None,
"return": trade.returns,
},
importance=importance,
tags=[trade.symbol, trade.direction.value],
)
self._situation_memory.add(entry)
def record_risk_decision(
self,
category: RiskCategory,
risk_level: float,
market_regime: MarketRegime,
context: str,
user_id: Optional[str] = None,
) -> str:
"""Record a risk decision for learning.
Args:
category: Risk category
risk_level: Risk level chosen
market_regime: Current market regime
context: Decision context
user_id: User ID
Returns:
Decision ID
"""
decision = RiskDecision.create(
category=category,
risk_level=risk_level,
market_regime=market_regime,
context=context,
)
return self._risk_memory.record_decision(decision, user_id)
def evaluate_risk_decision(
self,
decision_id: str,
outcome: str,
outcome_score: float,
was_appropriate: bool,
) -> None:
"""Evaluate a past risk decision.
Args:
decision_id: Decision ID to evaluate
outcome: What happened
outcome_score: Outcome score (-1 to 1)
was_appropriate: Whether decision was appropriate
"""
self._risk_memory.evaluate_decision(
decision_id=decision_id,
outcome=outcome,
outcome_score=outcome_score,
was_appropriate=was_appropriate,
)
def _format_trade_history(self, trades: List[TradeRecord]) -> str:
"""Format trade history for prompts."""
if not trades:
return "No recent trades."
lines = []
for trade in trades:
outcome = trade.outcome.value if trade.outcome else "pending"
ret = f"{trade.returns:+.2%}" if trade.returns else "N/A"
lines.append(
f"- {trade.entry_time.strftime('%Y-%m-%d')}: "
f"{trade.direction.value.upper()} {trade.symbol} @ ${trade.entry_price:.2f} "
f"-> {outcome} ({ret})"
)
return "\n".join(lines)
def _format_similar_situations(self, scored_entries) -> str:
"""Format similar situations for prompts."""
if not scored_entries:
return "No similar past situations found."
lines = []
for scored in scored_entries[:3]:
entry = scored.entry
score = scored.combined_score
lines.append(f"- (relevance: {score:.2f}) {entry.content[:200]}...")
return "\n".join(lines)
def _extract_lessons(self, trades: List[TradeRecord]) -> str:
"""Extract lessons learned from trades."""
if not trades:
return "No lessons to extract."
lessons = []
# Analyze winning vs losing trades
winners = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE]
losers = [t for t in trades if t.outcome == TradeOutcome.LOSS]
if winners and losers:
win_avg_hold = sum(
(t.exit_time - t.entry_time).days
for t in winners if t.exit_time
) / max(1, len([t for t in winners if t.exit_time]))
loss_avg_hold = sum(
(t.exit_time - t.entry_time).days
for t in losers if t.exit_time
) / max(1, len([t for t in losers if t.exit_time]))
if win_avg_hold < loss_avg_hold:
lessons.append("Winners tend to show profits quickly; consider cutting losers earlier.")
elif loss_avg_hold < win_avg_hold:
lessons.append("Holding winners longer has been profitable; avoid taking profits too early.")
# Look for patterns in agent reasoning
for trade in trades:
if trade.reasoning and trade.outcome:
if trade.outcome == TradeOutcome.PROFITABLE:
if trade.reasoning.research_conclusion:
lessons.append("Trades following analyst conclusions have been profitable.")
break
elif trade.outcome == TradeOutcome.LOSS:
if trade.reasoning.risk_assessment:
lessons.append("Consider risk assessment more carefully on future trades.")
break
if not lessons:
return "Continue following current strategy."
return "\n".join(f"- {lesson}" for lesson in lessons[:3])
def _extract_risk_lessons(self, trades: List[TradeRecord]) -> str:
"""Extract risk-specific lessons from trades."""
if not trades:
return "No risk lessons available."
lessons = []
# Analyze large losses
large_losses = [
t for t in trades
if t.returns and t.returns < -0.1
]
if large_losses:
lessons.append(
f"Had {len(large_losses)} trades with >10% losses. "
"Consider tighter stop-losses."
)
# Check for position sizing patterns
trades_with_size = [t for t in trades if t.quantity]
if trades_with_size:
large_positions = [
t for t in trades_with_size
if t.quantity > 100 and t.outcome == TradeOutcome.LOSS
]
if large_positions:
lessons.append(
"Larger positions have shown higher loss frequency. "
"Consider scaling in gradually."
)
if not lessons:
return "No specific risk warnings from recent history."
return "\n".join(f"- {lesson}" for lesson in lessons)
def to_dict(self) -> Dict[str, Any]:
"""Serialize to dictionary."""
return {
"trade_memory": self._trade_memory.to_dict(),
"risk_memory": self._risk_memory.to_dict(),
"situation_memory": self._situation_memory.to_dict(),
}
@classmethod
def from_dict(
cls,
data: Dict[str, Any],
embedding_function: Optional[Callable] = None,
) -> "AgentMemoryIntegration":
"""Create from dictionary."""
instance = cls(embedding_function=embedding_function)
if "trade_memory" in data:
instance._trade_memory = TradeHistoryMemory.from_dict(
data["trade_memory"],
embedding_function=embedding_function,
)
if "risk_memory" in data:
instance._risk_memory = RiskProfileMemory.from_dict(
data["risk_memory"],
embedding_function=embedding_function,
)
if "situation_memory" in data:
instance._situation_memory = LayeredMemory.from_dict(
data["situation_memory"],
embedding_function=embedding_function,
)
return instance
def create_memory_enhanced_prompt(
base_prompt: str,
context: MemoryContext,
context_types: Optional[List[ContextType]] = None,
) -> str:
"""Create a memory-enhanced prompt from a base prompt.
Args:
base_prompt: Original agent prompt
context: Memory context to include
context_types: Types of context to include
Returns:
Enhanced prompt with memory context
"""
if context.is_empty():
return base_prompt
memory_section = context.to_prompt_string(context_types)
return (
f"{base_prompt}\n\n"
f"---\n"
f"# Memory Context (Use this to inform your analysis)\n\n"
f"{memory_section}\n"
f"---"
)