diff --git a/tests/unit/memory/test_integration.py b/tests/unit/memory/test_integration.py new file mode 100644 index 00000000..461facfe --- /dev/null +++ b/tests/unit/memory/test_integration.py @@ -0,0 +1,548 @@ +"""Tests for Memory Integration module. + +Issue #21: [MEM-20] Memory integration - retrieval in agent prompts +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +from tradingagents.memory.integration import ( + AgentMemoryIntegration, + MemoryContext, + ContextType, + create_memory_enhanced_prompt, +) +from tradingagents.memory.trade_history import ( + TradeRecord, + TradeOutcome, + TradeDirection, + SignalStrength, + AgentReasoning, +) +from tradingagents.memory.risk_profiles import ( + RiskCategory, + MarketRegime, + RiskTolerance, + RiskProfile, +) + + +# ============================================================================= +# MemoryContext Tests +# ============================================================================= + + +class TestMemoryContext: + """Tests for MemoryContext dataclass.""" + + def test_empty_context(self): + """Test empty context detection.""" + context = MemoryContext() + assert context.is_empty() + + def test_non_empty_context(self): + """Test non-empty context detection.""" + context = MemoryContext(trade_history="Some trade history") + assert not context.is_empty() + + def test_to_prompt_string_empty(self): + """Test prompt string for empty context.""" + context = MemoryContext() + result = context.to_prompt_string() + assert "No relevant memory context" in result + + def test_to_prompt_string_with_history(self): + """Test prompt string with trade history.""" + context = MemoryContext(trade_history="AAPL: +5% last week") + result = context.to_prompt_string() + + assert "Recent Trade History" in result + assert "AAPL: +5% last week" in result + + def test_to_prompt_string_with_all_sections(self): + """Test prompt string with all sections.""" + context = MemoryContext( + trade_history="Trade history content", + risk_context="Risk context content", + similar_situations="Similar situations content", + lessons_learned="Lessons learned content", + ) + result = context.to_prompt_string() + + assert "Recent Trade History" in result + assert "Risk Profile Context" in result + assert "Similar Past Situations" in result + assert "Lessons Learned" in result + + def test_to_prompt_string_filter_by_type(self): + """Test filtering context by type.""" + context = MemoryContext( + trade_history="Trade history content", + risk_context="Risk context content", + lessons_learned="Lessons learned content", + ) + + # Only trade history + result = context.to_prompt_string([ContextType.TRADE_HISTORY]) + assert "Recent Trade History" in result + assert "Risk Profile Context" not in result + assert "Lessons Learned" not in result + + def test_to_prompt_string_multiple_types(self): + """Test filtering with multiple types.""" + context = MemoryContext( + trade_history="Trade history content", + risk_context="Risk context content", + lessons_learned="Lessons learned content", + ) + + result = context.to_prompt_string([ + ContextType.TRADE_HISTORY, + ContextType.RISK_PROFILE, + ]) + + assert "Recent Trade History" in result + assert "Risk Profile Context" in result + assert "Lessons Learned" not in result + + +# ============================================================================= +# AgentMemoryIntegration Tests +# ============================================================================= + + +class TestAgentMemoryIntegration: + """Tests for AgentMemoryIntegration class.""" + + def test_create_integration(self): + """Test creating integration instance.""" + integration = AgentMemoryIntegration() + assert integration.trade_memory is not None + assert integration.risk_memory is not None + assert integration.situation_memory is not None + + def test_get_analyst_context_empty(self): + """Test getting analyst context with no history.""" + integration = AgentMemoryIntegration() + + context = integration.get_analyst_context( + symbol="AAPL", + current_situation="Tech sector showing strength", + analyst_type="momentum", + ) + + # Should return context (may be empty) + assert isinstance(context, MemoryContext) + + def test_get_analyst_context_with_trades(self): + """Test getting analyst context with trade history.""" + integration = AgentMemoryIntegration() + + # Add some trades + trade = TradeRecord.create( + symbol="AAPL", + direction=TradeDirection.LONG, + entry_price=150.0, + quantity=100, + ) + trade.close(exit_price=160.0) + integration.trade_memory.record_trade(trade) + + context = integration.get_analyst_context( + symbol="AAPL", + current_situation="Tech rally", + analyst_type="momentum", + ) + + assert len(context.raw_trades) > 0 + assert "AAPL" in context.trade_history + + def test_get_trader_context_empty(self): + """Test getting trader context with no history.""" + integration = AgentMemoryIntegration() + + context = integration.get_trader_context( + symbol="TSLA", + current_situation="EV sector momentum", + proposed_action="buy", + ) + + assert isinstance(context, MemoryContext) + + def test_get_trader_context_with_regime(self): + """Test getting trader context with market regime.""" + integration = AgentMemoryIntegration() + + # Set up a profile + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + integration.risk_memory.set_profile(profile) + + context = integration.get_trader_context( + symbol="AAPL", + current_situation="Bull market", + proposed_action="buy", + market_regime=MarketRegime.BULL, + ) + + assert "Recommended risk level" in context.risk_context + assert "moderate" in context.risk_context.lower() + + def test_get_risk_manager_context(self): + """Test getting risk manager context.""" + integration = AgentMemoryIntegration() + + # Add some trades with outcomes + for i in range(5): + trade = TradeRecord.create( + symbol="MSFT", + direction=TradeDirection.LONG, + entry_price=300.0 + i, + quantity=100, + ) + # Some winners, some losers + if i % 2 == 0: + trade.close(exit_price=310.0 + i) + else: + trade.close(exit_price=290.0 + i) + integration.trade_memory.record_trade(trade) + + context = integration.get_risk_manager_context( + symbol="MSFT", + proposed_trade="Buy 100 shares at $305", + position_size=30500, + market_regime=MarketRegime.BULL, + ) + + assert "Trading history for MSFT" in context.trade_history + assert "Win rate" in context.trade_history + + def test_record_trade_outcome(self): + """Test recording a trade outcome.""" + integration = AgentMemoryIntegration() + + trade = TradeRecord.create( + symbol="GOOGL", + direction=TradeDirection.LONG, + entry_price=140.0, + quantity=50, + ) + trade.close(exit_price=150.0) + + integration.record_trade_outcome( + trade=trade, + situation_context="Tech sector showing AI momentum", + lesson_learned="AI momentum trades tend to work well", + ) + + # Trade should be in memory + trades = integration.trade_memory.get_trades_by_symbol("GOOGL") + assert len(trades) == 1 + assert trades[0].symbol == "GOOGL" + + # Situation should be recorded + assert integration.situation_memory.count() == 1 + + def test_record_risk_decision(self): + """Test recording a risk decision.""" + integration = AgentMemoryIntegration() + + decision_id = integration.record_risk_decision( + category=RiskCategory.POSITION_SIZE, + risk_level=0.6, + market_regime=MarketRegime.BULL, + context="Strong momentum, increasing position", + ) + + assert decision_id is not None + + # Decision should be recorded + decision = integration.risk_memory.get_decision(decision_id) + assert decision is not None + assert decision.risk_level == 0.6 + + def test_evaluate_risk_decision(self): + """Test evaluating a risk decision.""" + integration = AgentMemoryIntegration() + + decision_id = integration.record_risk_decision( + category=RiskCategory.LEVERAGE, + risk_level=0.7, + market_regime=MarketRegime.LOW_VOLATILITY, + context="Low vol, using leverage", + ) + + integration.evaluate_risk_decision( + decision_id=decision_id, + outcome="Profitable trade with leverage", + outcome_score=0.6, + was_appropriate=True, + ) + + decision = integration.risk_memory.get_decision(decision_id) + assert decision.was_appropriate is True + assert decision.outcome_score == 0.6 + + def test_to_dict_and_from_dict(self): + """Test serialization roundtrip.""" + integration = AgentMemoryIntegration() + + # Add some data + trade = TradeRecord.create( + symbol="NVDA", + direction=TradeDirection.LONG, + entry_price=500.0, + quantity=20, + ) + trade.close(exit_price=550.0) + integration.trade_memory.record_trade(trade) + + profile = RiskProfile(user_id="test", base_tolerance=RiskTolerance.AGGRESSIVE) + integration.risk_memory.set_profile(profile) + + # Serialize and restore + data = integration.to_dict() + restored = AgentMemoryIntegration.from_dict(data) + + # Verify data preserved + trades = restored.trade_memory.get_trades_by_symbol("NVDA") + assert len(trades) == 1 + + profile = restored.risk_memory.get_profile("test") + assert profile is not None + assert profile.base_tolerance == RiskTolerance.AGGRESSIVE + + +# ============================================================================= +# Helper Function Tests +# ============================================================================= + + +class TestCreateMemoryEnhancedPrompt: + """Tests for create_memory_enhanced_prompt function.""" + + def test_empty_context_returns_base(self): + """Test that empty context returns base prompt unchanged.""" + base_prompt = "Analyze the stock" + context = MemoryContext() + + result = create_memory_enhanced_prompt(base_prompt, context) + + assert result == base_prompt + + def test_adds_memory_section(self): + """Test that memory section is added.""" + base_prompt = "Analyze the stock" + context = MemoryContext(trade_history="Previous trade: +5%") + + result = create_memory_enhanced_prompt(base_prompt, context) + + assert "Analyze the stock" in result + assert "Memory Context" in result + assert "Previous trade: +5%" in result + + def test_respects_context_types(self): + """Test filtering by context types.""" + base_prompt = "Analyze" + context = MemoryContext( + trade_history="Trade history", + risk_context="Risk context", + ) + + result = create_memory_enhanced_prompt( + base_prompt, + context, + [ContextType.TRADE_HISTORY], + ) + + assert "Trade history" in result + assert "Risk context" not in result + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestAgentMemoryIntegrationWorkflow: + """Integration tests for complete memory workflow.""" + + def test_full_trading_workflow(self): + """Test a complete trading workflow with memory.""" + integration = AgentMemoryIntegration() + + # 1. Set up user profile + profile = RiskProfile( + user_id="trader1", + base_tolerance=RiskTolerance.MODERATE, + ) + integration.risk_memory.set_profile(profile) + + # 2. Record some historical trades + historical_trades = [ + ("AAPL", TradeDirection.LONG, 150.0, 160.0), + ("AAPL", TradeDirection.LONG, 155.0, 165.0), + ("AAPL", TradeDirection.LONG, 160.0, 155.0), # Loss + ] + + for symbol, direction, entry, exit in historical_trades: + trade = TradeRecord.create( + symbol=symbol, + direction=direction, + entry_price=entry, + quantity=100, + ) + trade.close(exit_price=exit) + integration.record_trade_outcome( + trade=trade, + situation_context=f"Trade in {symbol}", + ) + + # 3. Get analyst context for new analysis + analyst_context = integration.get_analyst_context( + symbol="AAPL", + current_situation="Apple showing momentum", + analyst_type="momentum", + ) + + assert len(analyst_context.raw_trades) == 3 + assert "AAPL" in analyst_context.trade_history + + # 4. Get trader context for decision + trader_context = integration.get_trader_context( + symbol="AAPL", + current_situation="Strong momentum signal", + proposed_action="buy", + market_regime=MarketRegime.BULL, + user_id="trader1", + ) + + assert "Recommended risk level" in trader_context.risk_context + + # 5. Get risk manager context + risk_context = integration.get_risk_manager_context( + symbol="AAPL", + proposed_trade="Buy 100 shares", + position_size=16000, + market_regime=MarketRegime.BULL, + user_id="trader1", + ) + + assert "Win rate" in risk_context.trade_history + assert "moderate" in risk_context.risk_context.lower() + + def test_memory_influences_recommendations(self): + """Test that memory influences risk recommendations.""" + integration = AgentMemoryIntegration() + + # Set up profile + profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE) + integration.risk_memory.set_profile(profile) + + # Record successful high-risk decisions + for i in range(5): + decision_id = integration.record_risk_decision( + category=RiskCategory.POSITION_SIZE, + risk_level=0.7, + market_regime=MarketRegime.BULL, + context=f"Successful trade {i}", + ) + integration.evaluate_risk_decision( + decision_id=decision_id, + outcome="Profitable", + outcome_score=0.6, + was_appropriate=True, + ) + + # Get recommendation - should be influenced by history + risk_level, explanation = integration.risk_memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=MarketRegime.BULL, + context="Similar successful situation", + ) + + # With successful high-risk history, recommendation should be elevated + # Base moderate (0.375) + bull adjustment (0.1) = 0.475 + # But with history of successful 0.7 decisions, should be higher + assert risk_level > 0.5 + + def test_lessons_extracted_from_trades(self): + """Test that lessons are extracted from trade patterns.""" + integration = AgentMemoryIntegration() + + # Record winning trades with short hold times + for i in range(3): + trade = TradeRecord.create( + symbol="TSLA", + direction=TradeDirection.LONG, + entry_price=200.0, + quantity=50, + ) + # Quick win - close the trade + trade.close(exit_price=220.0) + integration.trade_memory.record_trade(trade) + + # Record losing trades with long hold times + for i in range(3): + trade = TradeRecord.create( + symbol="TSLA", + direction=TradeDirection.LONG, + entry_price=200.0, + quantity=50, + ) + # Slow loss + trade.close(exit_price=180.0) + integration.trade_memory.record_trade(trade) + + context = integration.get_analyst_context( + symbol="TSLA", + current_situation="EV sector", + analyst_type="general", + ) + + # Lessons should exist (may be strategy continuation or specific lesson) + assert context.lessons_learned != "" + + +class TestContextTypeFiltering: + """Tests for context type filtering.""" + + def test_all_types(self): + """Test ALL context type includes everything.""" + context = MemoryContext( + trade_history="History", + risk_context="Risk", + similar_situations="Similar", + lessons_learned="Lessons", + ) + + result = context.to_prompt_string([ContextType.ALL]) + + assert "History" in result + assert "Risk" in result + assert "Similar" in result + assert "Lessons" in result + + def test_single_type(self): + """Test filtering to single type.""" + context = MemoryContext( + trade_history="History", + risk_context="Risk", + ) + + result = context.to_prompt_string([ContextType.RISK_PROFILE]) + + assert "History" not in result + assert "Risk" in result + + def test_empty_sections_not_included(self): + """Test that empty sections aren't shown.""" + context = MemoryContext( + trade_history="History", + risk_context="", # Empty + ) + + result = context.to_prompt_string([ContextType.ALL]) + + assert "Recent Trade History" in result + assert "Risk Profile Context" not in result diff --git a/tradingagents/agents/utils/agent_states.py b/tradingagents/agents/utils/agent_states.py index eeb53303..24beb380 100644 --- a/tradingagents/agents/utils/agent_states.py +++ b/tradingagents/agents/utils/agent_states.py @@ -64,6 +64,11 @@ class AgentState(MessagesState): macro_report: Annotated[str, "Report from the Macro Analyst"] correlation_report: Annotated[str, "Report from the Correlation Analyst"] + # memory context (Issue #21) + memory_context: Annotated[str, "Memory context for agents including past trades and lessons"] + trade_history_context: Annotated[str, "Relevant past trade history for this ticker"] + risk_context: Annotated[str, "Risk profile context and recommendations"] + # researcher team discussion step investment_debate_state: Annotated[ InvestDebateState, "Current state of the debate on if to invest or not" diff --git a/tradingagents/memory/__init__.py b/tradingagents/memory/__init__.py index f4155d83..70490507 100644 --- a/tradingagents/memory/__init__.py +++ b/tradingagents/memory/__init__.py @@ -8,6 +8,7 @@ This module provides a layered memory system with three scoring dimensions: Issue #18: Layered memory - recency, relevancy, importance scoring Issue #19: Trade history memory - outcomes, agent reasoning Issue #20: Risk profiles memory - user preferences over time +Issue #21: Memory integration - retrieval in agent prompts """ from .layered_memory import ( @@ -38,6 +39,13 @@ from .risk_profiles import ( RiskCategory, ) +from .integration import ( + AgentMemoryIntegration, + MemoryContext, + ContextType, + create_memory_enhanced_prompt, +) + __all__ = [ # Layered Memory (Issue #18) "LayeredMemory", @@ -61,4 +69,9 @@ __all__ = [ "RiskTolerance", "MarketRegime", "RiskCategory", + # Memory Integration (Issue #21) + "AgentMemoryIntegration", + "MemoryContext", + "ContextType", + "create_memory_enhanced_prompt", ] diff --git a/tradingagents/memory/integration.py b/tradingagents/memory/integration.py new file mode 100644 index 00000000..279f9645 --- /dev/null +++ b/tradingagents/memory/integration.py @@ -0,0 +1,600 @@ +"""Memory integration for agent prompts. + +This module provides integration between the memory system and agent prompts, +enabling agents to access relevant historical context for better decision-making. + +Issue #21: [MEM-20] Memory integration - retrieval in agent prompts +""" + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any, Callable +from enum import Enum + +from .layered_memory import LayeredMemory, MemoryEntry, MemoryConfig, ScoringWeights +from .trade_history import TradeHistoryMemory, TradeRecord, TradeOutcome, TradeDirection +from .risk_profiles import ( + RiskProfileMemory, + RiskProfile, + RiskDecision, + RiskTolerance, + MarketRegime, + RiskCategory, +) + + +class ContextType(Enum): + """Types of memory context that can be retrieved.""" + TRADE_HISTORY = "trade_history" + RISK_PROFILE = "risk_profile" + SIMILAR_SITUATIONS = "similar_situations" + LESSONS_LEARNED = "lessons_learned" + ALL = "all" + + +@dataclass +class MemoryContext: + """Memory context for agent prompts. + + Attributes: + trade_history: Summary of relevant past trades + risk_context: Risk profile recommendations + similar_situations: Similar past situations and outcomes + lessons_learned: Key lessons from past trades + raw_trades: List of relevant TradeRecord objects + """ + trade_history: str = "" + risk_context: str = "" + similar_situations: str = "" + lessons_learned: str = "" + raw_trades: List[TradeRecord] = field(default_factory=list) + + def to_prompt_string(self, include_types: Optional[List[ContextType]] = None) -> str: + """Convert memory context to a string for agent prompts. + + Args: + include_types: Types of context to include (default: all non-empty) + + Returns: + Formatted string for agent prompts + """ + if include_types is None: + include_types = [ContextType.ALL] + + parts = [] + + if ContextType.ALL in include_types or ContextType.TRADE_HISTORY in include_types: + if self.trade_history: + parts.append(f"## Recent Trade History\n{self.trade_history}") + + if ContextType.ALL in include_types or ContextType.RISK_PROFILE in include_types: + if self.risk_context: + parts.append(f"## Risk Profile Context\n{self.risk_context}") + + if ContextType.ALL in include_types or ContextType.SIMILAR_SITUATIONS in include_types: + if self.similar_situations: + parts.append(f"## Similar Past Situations\n{self.similar_situations}") + + if ContextType.ALL in include_types or ContextType.LESSONS_LEARNED in include_types: + if self.lessons_learned: + parts.append(f"## Lessons Learned\n{self.lessons_learned}") + + if not parts: + return "No relevant memory context available." + + return "\n\n".join(parts) + + def is_empty(self) -> bool: + """Check if context is empty.""" + return not any([ + self.trade_history, + self.risk_context, + self.similar_situations, + self.lessons_learned, + ]) + + +class AgentMemoryIntegration: + """Integration layer between memory systems and agent prompts. + + This class provides methods to retrieve relevant memory context + for different agents in the trading system. + + Example: + >>> integration = AgentMemoryIntegration() + >>> + >>> # Get context for an analyst + >>> context = integration.get_analyst_context( + ... ticker="AAPL", + ... current_situation="Tech sector showing momentum", + ... analyst_type="momentum", + ... ) + >>> + >>> # Use in prompt + >>> prompt = f"Analyze {ticker}. Memory context: {context.to_prompt_string()}" + """ + + def __init__( + self, + trade_memory: Optional[TradeHistoryMemory] = None, + risk_memory: Optional[RiskProfileMemory] = None, + situation_memory: Optional[LayeredMemory] = None, + embedding_function: Optional[Callable] = None, + ): + """Initialize memory integration. + + Args: + trade_memory: Trade history memory instance + risk_memory: Risk profile memory instance + situation_memory: General situation memory + embedding_function: Optional embedding function for similarity + """ + self._trade_memory = trade_memory or TradeHistoryMemory() + self._risk_memory = risk_memory or RiskProfileMemory() + self._situation_memory = situation_memory or LayeredMemory( + config=MemoryConfig( + weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2) + ), + embedding_function=embedding_function, + ) + self._embedding_function = embedding_function + + @property + def trade_memory(self) -> TradeHistoryMemory: + """Access trade history memory.""" + return self._trade_memory + + @property + def risk_memory(self) -> RiskProfileMemory: + """Access risk profile memory.""" + return self._risk_memory + + @property + def situation_memory(self) -> LayeredMemory: + """Access situation memory.""" + return self._situation_memory + + def get_analyst_context( + self, + symbol: str, + current_situation: str, + analyst_type: str = "general", + lookback_days: int = 90, + max_trades: int = 5, + user_id: Optional[str] = None, + ) -> MemoryContext: + """Get memory context for an analyst agent. + + Args: + symbol: Stock symbol being analyzed + current_situation: Current market situation description + analyst_type: Type of analyst (momentum, macro, etc.) + lookback_days: Days to look back for trades + max_trades: Maximum trades to include + user_id: User ID for risk profile + + Returns: + MemoryContext with relevant information + """ + context = MemoryContext() + + # Get relevant past trades for this symbol + trades = self._trade_memory.get_trades_by_symbol(symbol) + + if trades: + # Filter to lookback period + cutoff = datetime.now() - timedelta(days=lookback_days) + recent_trades = [t for t in trades if t.entry_time >= cutoff][:max_trades] + + if recent_trades: + context.raw_trades = recent_trades + context.trade_history = self._format_trade_history(recent_trades) + context.lessons_learned = self._extract_lessons(recent_trades) + + # Get similar situations + similar = self._situation_memory.retrieve( + query=current_situation, + top_k=3, + tags=[analyst_type] if analyst_type != "general" else None, + ) + + if similar: + context.similar_situations = self._format_similar_situations(similar) + + return context + + def get_trader_context( + self, + symbol: str, + current_situation: str, + proposed_action: str, + market_regime: Optional[MarketRegime] = None, + user_id: Optional[str] = None, + ) -> MemoryContext: + """Get memory context for the trader agent. + + Args: + symbol: Stock symbol being traded + current_situation: Current market situation + proposed_action: Proposed trade action (buy/sell/hold) + market_regime: Current market regime + user_id: User ID for risk profile + + Returns: + MemoryContext with relevant information + """ + context = MemoryContext() + + # Get past trades for this symbol + trades = self._trade_memory.get_trades_by_symbol(symbol) + if trades: + context.raw_trades = trades[:5] + context.trade_history = self._format_trade_history(trades[:5]) + context.lessons_learned = self._extract_lessons(trades) + + # Get risk profile context + if market_regime: + risk_level, explanation = self._risk_memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=market_regime, + context=current_situation, + user_id=user_id, + ) + profile = self._risk_memory.get_or_create_profile(user_id) + context.risk_context = ( + f"Recommended risk level: {risk_level:.2f}\n" + f"Base tolerance: {profile.base_tolerance.value}\n" + f"Reasoning: {explanation}" + ) + + # Get similar trading situations + similar = self._situation_memory.retrieve( + query=f"{current_situation} {proposed_action}", + top_k=3, + ) + + if similar: + context.similar_situations = self._format_similar_situations(similar) + + return context + + def get_risk_manager_context( + self, + symbol: str, + proposed_trade: str, + position_size: float, + market_regime: Optional[MarketRegime] = None, + user_id: Optional[str] = None, + ) -> MemoryContext: + """Get memory context for risk management agent. + + Args: + symbol: Stock symbol + proposed_trade: Proposed trade description + position_size: Proposed position size + market_regime: Current market regime + user_id: User ID + + Returns: + MemoryContext with risk-focused information + """ + context = MemoryContext() + + # Get past trades with outcome statistics + trades = self._trade_memory.get_trades_by_symbol(symbol) + if trades: + winning = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE] + losing = [t for t in trades if t.outcome == TradeOutcome.LOSS] + + win_rate = len(winning) / len(trades) if trades else 0 + avg_return = sum( + t.returns or 0 for t in trades if t.returns + ) / max(1, len([t for t in trades if t.returns])) + + context.trade_history = ( + f"Trading history for {symbol}:\n" + f"- Total trades: {len(trades)}\n" + f"- Win rate: {win_rate:.1%}\n" + f"- Average return: {avg_return:.2%}\n" + f"- Winners: {len(winning)}, Losers: {len(losing)}" + ) + + # Extract risk lessons + context.lessons_learned = self._extract_risk_lessons(trades) + + # Get risk profile and recommendations + if market_regime: + profile = self._risk_memory.get_or_create_profile(user_id) + adjusted_tolerance = profile.get_adjusted_tolerance(market_regime) + + risk_level, explanation = self._risk_memory.recommend_risk_level( + category=RiskCategory.POSITION_SIZE, + market_regime=market_regime, + context=proposed_trade, + user_id=user_id, + ) + + context.risk_context = ( + f"User risk profile:\n" + f"- Base tolerance: {profile.base_tolerance.value}\n" + f"- Adjusted for {market_regime.value}: {adjusted_tolerance.value}\n" + f"- Max drawdown tolerance: {profile.max_drawdown_tolerance:.1%}\n" + f"- Recommended risk level: {risk_level:.2f}\n" + f"- Reasoning: {explanation}" + ) + + return context + + def record_trade_outcome( + self, + trade: TradeRecord, + situation_context: str, + lesson_learned: Optional[str] = None, + ) -> None: + """Record a trade outcome for future reference. + + Args: + trade: The completed trade record + situation_context: Description of the market situation + lesson_learned: Optional lesson to remember + """ + # Record in trade memory + self._trade_memory.record_trade(trade) + + # Record situation for future similarity matching + importance = 0.5 + if trade.returns: + # Higher importance for significant outcomes + importance = min(1.0, 0.5 + abs(trade.returns)) + + entry_content = ( + f"Trade: {trade.direction.value} {trade.symbol} at {trade.entry_price}. " + f"Outcome: {trade.outcome.value if trade.outcome else 'pending'}. " + f"Context: {situation_context}" + ) + + if lesson_learned: + entry_content += f"\nLesson: {lesson_learned}" + + entry = MemoryEntry.create( + content=entry_content, + metadata={ + "trade_id": trade.id, + "symbol": trade.symbol, + "direction": trade.direction.value, + "outcome": trade.outcome.value if trade.outcome else None, + "return": trade.returns, + }, + importance=importance, + tags=[trade.symbol, trade.direction.value], + ) + + self._situation_memory.add(entry) + + def record_risk_decision( + self, + category: RiskCategory, + risk_level: float, + market_regime: MarketRegime, + context: str, + user_id: Optional[str] = None, + ) -> str: + """Record a risk decision for learning. + + Args: + category: Risk category + risk_level: Risk level chosen + market_regime: Current market regime + context: Decision context + user_id: User ID + + Returns: + Decision ID + """ + decision = RiskDecision.create( + category=category, + risk_level=risk_level, + market_regime=market_regime, + context=context, + ) + + return self._risk_memory.record_decision(decision, user_id) + + def evaluate_risk_decision( + self, + decision_id: str, + outcome: str, + outcome_score: float, + was_appropriate: bool, + ) -> None: + """Evaluate a past risk decision. + + Args: + decision_id: Decision ID to evaluate + outcome: What happened + outcome_score: Outcome score (-1 to 1) + was_appropriate: Whether decision was appropriate + """ + self._risk_memory.evaluate_decision( + decision_id=decision_id, + outcome=outcome, + outcome_score=outcome_score, + was_appropriate=was_appropriate, + ) + + def _format_trade_history(self, trades: List[TradeRecord]) -> str: + """Format trade history for prompts.""" + if not trades: + return "No recent trades." + + lines = [] + for trade in trades: + outcome = trade.outcome.value if trade.outcome else "pending" + ret = f"{trade.returns:+.2%}" if trade.returns else "N/A" + lines.append( + f"- {trade.entry_time.strftime('%Y-%m-%d')}: " + f"{trade.direction.value.upper()} {trade.symbol} @ ${trade.entry_price:.2f} " + f"-> {outcome} ({ret})" + ) + + return "\n".join(lines) + + def _format_similar_situations(self, scored_entries) -> str: + """Format similar situations for prompts.""" + if not scored_entries: + return "No similar past situations found." + + lines = [] + for scored in scored_entries[:3]: + entry = scored.entry + score = scored.combined_score + lines.append(f"- (relevance: {score:.2f}) {entry.content[:200]}...") + + return "\n".join(lines) + + def _extract_lessons(self, trades: List[TradeRecord]) -> str: + """Extract lessons learned from trades.""" + if not trades: + return "No lessons to extract." + + lessons = [] + + # Analyze winning vs losing trades + winners = [t for t in trades if t.outcome == TradeOutcome.PROFITABLE] + losers = [t for t in trades if t.outcome == TradeOutcome.LOSS] + + if winners and losers: + win_avg_hold = sum( + (t.exit_time - t.entry_time).days + for t in winners if t.exit_time + ) / max(1, len([t for t in winners if t.exit_time])) + + loss_avg_hold = sum( + (t.exit_time - t.entry_time).days + for t in losers if t.exit_time + ) / max(1, len([t for t in losers if t.exit_time])) + + if win_avg_hold < loss_avg_hold: + lessons.append("Winners tend to show profits quickly; consider cutting losers earlier.") + elif loss_avg_hold < win_avg_hold: + lessons.append("Holding winners longer has been profitable; avoid taking profits too early.") + + # Look for patterns in agent reasoning + for trade in trades: + if trade.reasoning and trade.outcome: + if trade.outcome == TradeOutcome.PROFITABLE: + if trade.reasoning.research_conclusion: + lessons.append("Trades following analyst conclusions have been profitable.") + break + elif trade.outcome == TradeOutcome.LOSS: + if trade.reasoning.risk_assessment: + lessons.append("Consider risk assessment more carefully on future trades.") + break + + if not lessons: + return "Continue following current strategy." + + return "\n".join(f"- {lesson}" for lesson in lessons[:3]) + + def _extract_risk_lessons(self, trades: List[TradeRecord]) -> str: + """Extract risk-specific lessons from trades.""" + if not trades: + return "No risk lessons available." + + lessons = [] + + # Analyze large losses + large_losses = [ + t for t in trades + if t.returns and t.returns < -0.1 + ] + + if large_losses: + lessons.append( + f"Had {len(large_losses)} trades with >10% losses. " + "Consider tighter stop-losses." + ) + + # Check for position sizing patterns + trades_with_size = [t for t in trades if t.quantity] + if trades_with_size: + large_positions = [ + t for t in trades_with_size + if t.quantity > 100 and t.outcome == TradeOutcome.LOSS + ] + if large_positions: + lessons.append( + "Larger positions have shown higher loss frequency. " + "Consider scaling in gradually." + ) + + if not lessons: + return "No specific risk warnings from recent history." + + return "\n".join(f"- {lesson}" for lesson in lessons) + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dictionary.""" + return { + "trade_memory": self._trade_memory.to_dict(), + "risk_memory": self._risk_memory.to_dict(), + "situation_memory": self._situation_memory.to_dict(), + } + + @classmethod + def from_dict( + cls, + data: Dict[str, Any], + embedding_function: Optional[Callable] = None, + ) -> "AgentMemoryIntegration": + """Create from dictionary.""" + instance = cls(embedding_function=embedding_function) + + if "trade_memory" in data: + instance._trade_memory = TradeHistoryMemory.from_dict( + data["trade_memory"], + embedding_function=embedding_function, + ) + + if "risk_memory" in data: + instance._risk_memory = RiskProfileMemory.from_dict( + data["risk_memory"], + embedding_function=embedding_function, + ) + + if "situation_memory" in data: + instance._situation_memory = LayeredMemory.from_dict( + data["situation_memory"], + embedding_function=embedding_function, + ) + + return instance + + +def create_memory_enhanced_prompt( + base_prompt: str, + context: MemoryContext, + context_types: Optional[List[ContextType]] = None, +) -> str: + """Create a memory-enhanced prompt from a base prompt. + + Args: + base_prompt: Original agent prompt + context: Memory context to include + context_types: Types of context to include + + Returns: + Enhanced prompt with memory context + """ + if context.is_empty(): + return base_prompt + + memory_section = context.to_prompt_string(context_types) + + return ( + f"{base_prompt}\n\n" + f"---\n" + f"# Memory Context (Use this to inform your analysis)\n\n" + f"{memory_section}\n" + f"---" + )