# Source: TradingAgents/tests/unit/memory/test_integration.py
# (Python; listed as 549 lines, 18 KiB where this file was originally published)

"""Tests for Memory Integration module.
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from tradingagents.memory.integration import (
AgentMemoryIntegration,
MemoryContext,
ContextType,
create_memory_enhanced_prompt,
)
from tradingagents.memory.trade_history import (
TradeRecord,
TradeOutcome,
TradeDirection,
SignalStrength,
AgentReasoning,
)
from tradingagents.memory.risk_profiles import (
RiskCategory,
MarketRegime,
RiskTolerance,
RiskProfile,
)
# =============================================================================
# MemoryContext Tests
# =============================================================================
class TestMemoryContext:
    """Unit tests for the MemoryContext dataclass."""

    def test_empty_context(self):
        """A context created without arguments reports itself as empty."""
        assert MemoryContext().is_empty()

    def test_non_empty_context(self):
        """A context carrying trade history is not empty."""
        populated = MemoryContext(trade_history="Some trade history")
        assert not populated.is_empty()

    def test_to_prompt_string_empty(self):
        """An empty context renders the 'no relevant memory' placeholder."""
        rendered = MemoryContext().to_prompt_string()
        assert "No relevant memory context" in rendered

    def test_to_prompt_string_with_history(self):
        """Trade history appears in the prompt under its section header."""
        rendered = MemoryContext(
            trade_history="AAPL: +5% last week",
        ).to_prompt_string()

        for expected in ("Recent Trade History", "AAPL: +5% last week"):
            assert expected in rendered

    def test_to_prompt_string_with_all_sections(self):
        """Every populated section contributes its header to the prompt."""
        full_context = MemoryContext(
            trade_history="Trade history content",
            risk_context="Risk context content",
            similar_situations="Similar situations content",
            lessons_learned="Lessons learned content",
        )
        rendered = full_context.to_prompt_string()

        section_headers = (
            "Recent Trade History",
            "Risk Profile Context",
            "Similar Past Situations",
            "Lessons Learned",
        )
        for header in section_headers:
            assert header in rendered

    def test_to_prompt_string_filter_by_type(self):
        """Requesting one context type hides the other sections."""
        three_sections = MemoryContext(
            trade_history="Trade history content",
            risk_context="Risk context content",
            lessons_learned="Lessons learned content",
        )

        # Ask for the trade-history section only.
        rendered = three_sections.to_prompt_string([ContextType.TRADE_HISTORY])

        assert "Recent Trade History" in rendered
        for hidden_header in ("Risk Profile Context", "Lessons Learned"):
            assert hidden_header not in rendered

    def test_to_prompt_string_multiple_types(self):
        """Requesting two context types shows exactly those two sections."""
        three_sections = MemoryContext(
            trade_history="Trade history content",
            risk_context="Risk context content",
            lessons_learned="Lessons learned content",
        )
        requested = [ContextType.TRADE_HISTORY, ContextType.RISK_PROFILE]
        rendered = three_sections.to_prompt_string(requested)

        for shown_header in ("Recent Trade History", "Risk Profile Context"):
            assert shown_header in rendered
        assert "Lessons Learned" not in rendered
# =============================================================================
# AgentMemoryIntegration Tests
# =============================================================================
class TestAgentMemoryIntegration:
    """Unit tests for the AgentMemoryIntegration class."""

    def test_create_integration(self):
        """A fresh integration exposes trade, risk and situation memories."""
        integration = AgentMemoryIntegration()
        memories = (
            integration.trade_memory,
            integration.risk_memory,
            integration.situation_memory,
        )
        for memory in memories:
            assert memory is not None

    def test_get_analyst_context_empty(self):
        """Analyst context is a MemoryContext even with no recorded history."""
        analyst_context = AgentMemoryIntegration().get_analyst_context(
            symbol="AAPL",
            current_situation="Tech sector showing strength",
            analyst_type="momentum",
        )
        # Only the return type is checked; the content may well be empty.
        assert isinstance(analyst_context, MemoryContext)

    def test_get_analyst_context_with_trades(self):
        """A recorded AAPL trade shows up in the analyst context."""
        integration = AgentMemoryIntegration()

        # Seed the trade memory with one closed AAPL long.
        aapl_trade = TradeRecord.create(
            symbol="AAPL",
            direction=TradeDirection.LONG,
            entry_price=150.0,
            quantity=100,
        )
        aapl_trade.close(exit_price=160.0)
        integration.trade_memory.record_trade(aapl_trade)

        analyst_context = integration.get_analyst_context(
            symbol="AAPL",
            current_situation="Tech rally",
            analyst_type="momentum",
        )

        assert len(analyst_context.raw_trades) > 0
        assert "AAPL" in analyst_context.trade_history

    def test_get_trader_context_empty(self):
        """Trader context is a MemoryContext even with no recorded history."""
        trader_context = AgentMemoryIntegration().get_trader_context(
            symbol="TSLA",
            current_situation="EV sector momentum",
            proposed_action="buy",
        )
        assert isinstance(trader_context, MemoryContext)

    def test_get_trader_context_with_regime(self):
        """With a profile and a market regime, risk context carries a recommendation."""
        integration = AgentMemoryIntegration()

        # Register a moderate-tolerance profile under the "default" user id.
        integration.risk_memory.set_profile(
            RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
        )

        trader_context = integration.get_trader_context(
            symbol="AAPL",
            current_situation="Bull market",
            proposed_action="buy",
            market_regime=MarketRegime.BULL,
        )

        risk_text = trader_context.risk_context
        assert "Recommended risk level" in risk_text
        assert "moderate" in risk_text.lower()

    def test_get_risk_manager_context(self):
        """Risk manager context summarises the symbol's trade history."""
        integration = AgentMemoryIntegration()

        # Five MSFT longs: even-indexed ones close above entry, odd-indexed below.
        for offset in range(5):
            msft_trade = TradeRecord.create(
                symbol="MSFT",
                direction=TradeDirection.LONG,
                entry_price=300.0 + offset,
                quantity=100,
            )
            base_exit = 310.0 if offset % 2 == 0 else 290.0
            msft_trade.close(exit_price=base_exit + offset)
            integration.trade_memory.record_trade(msft_trade)

        manager_context = integration.get_risk_manager_context(
            symbol="MSFT",
            proposed_trade="Buy 100 shares at $305",
            position_size=30500,
            market_regime=MarketRegime.BULL,
        )

        for expected in ("Trading history for MSFT", "Win rate"):
            assert expected in manager_context.trade_history

    def test_record_trade_outcome(self):
        """Recording an outcome stores both the trade and one situation."""
        integration = AgentMemoryIntegration()

        googl_trade = TradeRecord.create(
            symbol="GOOGL",
            direction=TradeDirection.LONG,
            entry_price=140.0,
            quantity=50,
        )
        googl_trade.close(exit_price=150.0)

        integration.record_trade_outcome(
            trade=googl_trade,
            situation_context="Tech sector showing AI momentum",
            lesson_learned="AI momentum trades tend to work well",
        )

        # The trade is retrievable by symbol ...
        stored_trades = integration.trade_memory.get_trades_by_symbol("GOOGL")
        assert len(stored_trades) == 1
        assert stored_trades[0].symbol == "GOOGL"

        # ... and exactly one situation has been recorded alongside it.
        assert integration.situation_memory.count() == 1

    def test_record_risk_decision(self):
        """A recorded risk decision can be fetched back by its id."""
        integration = AgentMemoryIntegration()

        new_id = integration.record_risk_decision(
            category=RiskCategory.POSITION_SIZE,
            risk_level=0.6,
            market_regime=MarketRegime.BULL,
            context="Strong momentum, increasing position",
        )
        assert new_id is not None

        # Look the decision up again through the risk memory.
        stored_decision = integration.risk_memory.get_decision(new_id)
        assert stored_decision is not None
        assert stored_decision.risk_level == 0.6

    def test_evaluate_risk_decision(self):
        """Evaluating a decision stores its appropriateness and outcome score."""
        integration = AgentMemoryIntegration()

        new_id = integration.record_risk_decision(
            category=RiskCategory.LEVERAGE,
            risk_level=0.7,
            market_regime=MarketRegime.LOW_VOLATILITY,
            context="Low vol, using leverage",
        )
        integration.evaluate_risk_decision(
            decision_id=new_id,
            outcome="Profitable trade with leverage",
            outcome_score=0.6,
            was_appropriate=True,
        )

        evaluated = integration.risk_memory.get_decision(new_id)
        assert evaluated.was_appropriate is True
        assert evaluated.outcome_score == 0.6

    def test_to_dict_and_from_dict(self):
        """Trades and risk profiles survive a to_dict / from_dict round trip."""
        original = AgentMemoryIntegration()

        # Populate the original with one closed trade and one risk profile.
        nvda_trade = TradeRecord.create(
            symbol="NVDA",
            direction=TradeDirection.LONG,
            entry_price=500.0,
            quantity=20,
        )
        nvda_trade.close(exit_price=550.0)
        original.trade_memory.record_trade(nvda_trade)
        original.risk_memory.set_profile(
            RiskProfile(user_id="test", base_tolerance=RiskTolerance.AGGRESSIVE)
        )

        # Round-trip through the dictionary representation.
        restored = AgentMemoryIntegration.from_dict(original.to_dict())

        # Both pieces of data must still be present afterwards.
        restored_trades = restored.trade_memory.get_trades_by_symbol("NVDA")
        assert len(restored_trades) == 1

        restored_profile = restored.risk_memory.get_profile("test")
        assert restored_profile is not None
        assert restored_profile.base_tolerance == RiskTolerance.AGGRESSIVE
# =============================================================================
# Helper Function Tests
# =============================================================================
class TestCreateMemoryEnhancedPrompt:
    """Unit tests for the create_memory_enhanced_prompt helper."""

    def test_empty_context_returns_base(self):
        """With an empty context the base prompt comes back unchanged."""
        base_prompt = "Analyze the stock"
        enhanced = create_memory_enhanced_prompt(base_prompt, MemoryContext())
        assert enhanced == base_prompt

    def test_adds_memory_section(self):
        """A populated context is added to the prompt next to the base text."""
        memory = MemoryContext(trade_history="Previous trade: +5%")
        enhanced = create_memory_enhanced_prompt("Analyze the stock", memory)

        expected_fragments = (
            "Analyze the stock",
            "Memory Context",
            "Previous trade: +5%",
        )
        for fragment in expected_fragments:
            assert fragment in enhanced

    def test_respects_context_types(self):
        """Only the requested context types reach the enhanced prompt."""
        memory = MemoryContext(
            trade_history="Trade history",
            risk_context="Risk context",
        )
        requested_types = [ContextType.TRADE_HISTORY]

        enhanced = create_memory_enhanced_prompt("Analyze", memory, requested_types)

        assert "Trade history" in enhanced
        assert "Risk context" not in enhanced
# =============================================================================
# Integration Tests
# =============================================================================
class TestAgentMemoryIntegrationWorkflow:
    """Integration tests that drive AgentMemoryIntegration end to end."""

    def test_full_trading_workflow(self):
        """Walk through profile setup, trade recording and all three agent contexts."""
        integration = AgentMemoryIntegration()

        # 1. Set up user profile
        profile = RiskProfile(
            user_id="trader1",
            base_tolerance=RiskTolerance.MODERATE,
        )
        integration.risk_memory.set_profile(profile)

        # 2. Record some historical trades: (symbol, direction, entry, exit)
        historical_trades = [
            ("AAPL", TradeDirection.LONG, 150.0, 160.0),
            ("AAPL", TradeDirection.LONG, 155.0, 165.0),
            ("AAPL", TradeDirection.LONG, 160.0, 155.0),  # Loss
        ]
        # `exit_price` rather than `exit`, so the builtin is not shadowed.
        for symbol, direction, entry_price, exit_price in historical_trades:
            trade = TradeRecord.create(
                symbol=symbol,
                direction=direction,
                entry_price=entry_price,
                quantity=100,
            )
            trade.close(exit_price=exit_price)
            integration.record_trade_outcome(
                trade=trade,
                situation_context=f"Trade in {symbol}",
            )

        # 3. Get analyst context for new analysis
        analyst_context = integration.get_analyst_context(
            symbol="AAPL",
            current_situation="Apple showing momentum",
            analyst_type="momentum",
        )
        assert len(analyst_context.raw_trades) == 3
        assert "AAPL" in analyst_context.trade_history

        # 4. Get trader context for decision
        trader_context = integration.get_trader_context(
            symbol="AAPL",
            current_situation="Strong momentum signal",
            proposed_action="buy",
            market_regime=MarketRegime.BULL,
            user_id="trader1",
        )
        assert "Recommended risk level" in trader_context.risk_context

        # 5. Get risk manager context
        risk_context = integration.get_risk_manager_context(
            symbol="AAPL",
            proposed_trade="Buy 100 shares",
            position_size=16000,
            market_regime=MarketRegime.BULL,
            user_id="trader1",
        )
        assert "Win rate" in risk_context.trade_history
        assert "moderate" in risk_context.risk_context.lower()

    def test_memory_influences_recommendations(self):
        """Positively evaluated 0.7-risk decisions lift the recommendation above 0.5."""
        integration = AgentMemoryIntegration()

        # Set up profile
        profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
        integration.risk_memory.set_profile(profile)

        # Record five high-risk decisions and mark each one as appropriate.
        for i in range(5):
            decision_id = integration.record_risk_decision(
                category=RiskCategory.POSITION_SIZE,
                risk_level=0.7,
                market_regime=MarketRegime.BULL,
                context=f"Successful trade {i}",
            )
            integration.evaluate_risk_decision(
                decision_id=decision_id,
                outcome="Profitable",
                outcome_score=0.6,
                was_appropriate=True,
            )

        # Get recommendation - should be influenced by history. The explanation
        # returned alongside the level is not inspected by this test.
        risk_level, _ = integration.risk_memory.recommend_risk_level(
            category=RiskCategory.POSITION_SIZE,
            market_regime=MarketRegime.BULL,
            context="Similar successful situation",
        )

        # NOTE(review): per the original author's note, the baseline would be
        # moderate (0.375) + bull adjustment (0.1) = 0.475; those constants live
        # in the risk profile module and are not visible here - confirm there.
        # With a history of successful 0.7 decisions the result should be higher.
        assert risk_level > 0.5

    def test_lessons_extracted_from_trades(self):
        """A mix of winning and losing TSLA trades yields non-empty lessons."""
        integration = AgentMemoryIntegration()

        # Three winning longs (200 -> 220). Entry and exit prices are the only
        # inputs set here; no hold time is configured on these trades.
        for _ in range(3):
            trade = TradeRecord.create(
                symbol="TSLA",
                direction=TradeDirection.LONG,
                entry_price=200.0,
                quantity=50,
            )
            trade.close(exit_price=220.0)
            integration.trade_memory.record_trade(trade)

        # Three losing longs (200 -> 180), built the same way.
        for _ in range(3):
            trade = TradeRecord.create(
                symbol="TSLA",
                direction=TradeDirection.LONG,
                entry_price=200.0,
                quantity=50,
            )
            trade.close(exit_price=180.0)
            integration.trade_memory.record_trade(trade)

        context = integration.get_analyst_context(
            symbol="TSLA",
            current_situation="EV sector",
            analyst_type="general",
        )

        # Lessons should exist (may be strategy continuation or specific lesson)
        assert context.lessons_learned != ""
class TestContextTypeFiltering:
    """Unit tests for filtering MemoryContext output by ContextType."""

    def test_all_types(self):
        """ContextType.ALL lets every populated section through."""
        everything = MemoryContext(
            trade_history="History",
            risk_context="Risk",
            similar_situations="Similar",
            lessons_learned="Lessons",
        )
        rendered = everything.to_prompt_string([ContextType.ALL])

        for fragment in ("History", "Risk", "Similar", "Lessons"):
            assert fragment in rendered

    def test_single_type(self):
        """Asking for the risk profile alone leaves trade history out."""
        two_sections = MemoryContext(
            trade_history="History",
            risk_context="Risk",
        )
        rendered = two_sections.to_prompt_string([ContextType.RISK_PROFILE])

        assert "History" not in rendered
        assert "Risk" in rendered

    def test_empty_sections_not_included(self):
        """A section whose text is empty gets no header in the output."""
        history_only = MemoryContext(
            trade_history="History",
            risk_context="",  # deliberately left empty
        )
        rendered = history_only.to_prompt_string([ContextType.ALL])

        assert "Recent Trade History" in rendered
        assert "Risk Profile Context" not in rendered