549 lines
18 KiB
Python
549 lines
18 KiB
Python
"""Tests for Memory Integration module.
|
|
|
|
Issue #21: [MEM-20] Memory integration - retrieval in agent prompts
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import Mock, patch
|
|
|
|
from tradingagents.memory.integration import (
|
|
AgentMemoryIntegration,
|
|
MemoryContext,
|
|
ContextType,
|
|
create_memory_enhanced_prompt,
|
|
)
|
|
from tradingagents.memory.trade_history import (
|
|
TradeRecord,
|
|
TradeOutcome,
|
|
TradeDirection,
|
|
SignalStrength,
|
|
AgentReasoning,
|
|
)
|
|
from tradingagents.memory.risk_profiles import (
|
|
RiskCategory,
|
|
MarketRegime,
|
|
RiskTolerance,
|
|
RiskProfile,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# MemoryContext Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestMemoryContext:
|
|
"""Tests for MemoryContext dataclass."""
|
|
|
|
def test_empty_context(self):
|
|
"""Test empty context detection."""
|
|
context = MemoryContext()
|
|
assert context.is_empty()
|
|
|
|
def test_non_empty_context(self):
|
|
"""Test non-empty context detection."""
|
|
context = MemoryContext(trade_history="Some trade history")
|
|
assert not context.is_empty()
|
|
|
|
def test_to_prompt_string_empty(self):
|
|
"""Test prompt string for empty context."""
|
|
context = MemoryContext()
|
|
result = context.to_prompt_string()
|
|
assert "No relevant memory context" in result
|
|
|
|
def test_to_prompt_string_with_history(self):
|
|
"""Test prompt string with trade history."""
|
|
context = MemoryContext(trade_history="AAPL: +5% last week")
|
|
result = context.to_prompt_string()
|
|
|
|
assert "Recent Trade History" in result
|
|
assert "AAPL: +5% last week" in result
|
|
|
|
def test_to_prompt_string_with_all_sections(self):
|
|
"""Test prompt string with all sections."""
|
|
context = MemoryContext(
|
|
trade_history="Trade history content",
|
|
risk_context="Risk context content",
|
|
similar_situations="Similar situations content",
|
|
lessons_learned="Lessons learned content",
|
|
)
|
|
result = context.to_prompt_string()
|
|
|
|
assert "Recent Trade History" in result
|
|
assert "Risk Profile Context" in result
|
|
assert "Similar Past Situations" in result
|
|
assert "Lessons Learned" in result
|
|
|
|
def test_to_prompt_string_filter_by_type(self):
|
|
"""Test filtering context by type."""
|
|
context = MemoryContext(
|
|
trade_history="Trade history content",
|
|
risk_context="Risk context content",
|
|
lessons_learned="Lessons learned content",
|
|
)
|
|
|
|
# Only trade history
|
|
result = context.to_prompt_string([ContextType.TRADE_HISTORY])
|
|
assert "Recent Trade History" in result
|
|
assert "Risk Profile Context" not in result
|
|
assert "Lessons Learned" not in result
|
|
|
|
def test_to_prompt_string_multiple_types(self):
|
|
"""Test filtering with multiple types."""
|
|
context = MemoryContext(
|
|
trade_history="Trade history content",
|
|
risk_context="Risk context content",
|
|
lessons_learned="Lessons learned content",
|
|
)
|
|
|
|
result = context.to_prompt_string([
|
|
ContextType.TRADE_HISTORY,
|
|
ContextType.RISK_PROFILE,
|
|
])
|
|
|
|
assert "Recent Trade History" in result
|
|
assert "Risk Profile Context" in result
|
|
assert "Lessons Learned" not in result
|
|
|
|
|
|
# =============================================================================
|
|
# AgentMemoryIntegration Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestAgentMemoryIntegration:
|
|
"""Tests for AgentMemoryIntegration class."""
|
|
|
|
def test_create_integration(self):
|
|
"""Test creating integration instance."""
|
|
integration = AgentMemoryIntegration()
|
|
assert integration.trade_memory is not None
|
|
assert integration.risk_memory is not None
|
|
assert integration.situation_memory is not None
|
|
|
|
def test_get_analyst_context_empty(self):
|
|
"""Test getting analyst context with no history."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
context = integration.get_analyst_context(
|
|
symbol="AAPL",
|
|
current_situation="Tech sector showing strength",
|
|
analyst_type="momentum",
|
|
)
|
|
|
|
# Should return context (may be empty)
|
|
assert isinstance(context, MemoryContext)
|
|
|
|
def test_get_analyst_context_with_trades(self):
|
|
"""Test getting analyst context with trade history."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Add some trades
|
|
trade = TradeRecord.create(
|
|
symbol="AAPL",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=150.0,
|
|
quantity=100,
|
|
)
|
|
trade.close(exit_price=160.0)
|
|
integration.trade_memory.record_trade(trade)
|
|
|
|
context = integration.get_analyst_context(
|
|
symbol="AAPL",
|
|
current_situation="Tech rally",
|
|
analyst_type="momentum",
|
|
)
|
|
|
|
assert len(context.raw_trades) > 0
|
|
assert "AAPL" in context.trade_history
|
|
|
|
def test_get_trader_context_empty(self):
|
|
"""Test getting trader context with no history."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
context = integration.get_trader_context(
|
|
symbol="TSLA",
|
|
current_situation="EV sector momentum",
|
|
proposed_action="buy",
|
|
)
|
|
|
|
assert isinstance(context, MemoryContext)
|
|
|
|
def test_get_trader_context_with_regime(self):
|
|
"""Test getting trader context with market regime."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Set up a profile
|
|
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
|
integration.risk_memory.set_profile(profile)
|
|
|
|
context = integration.get_trader_context(
|
|
symbol="AAPL",
|
|
current_situation="Bull market",
|
|
proposed_action="buy",
|
|
market_regime=MarketRegime.BULL,
|
|
)
|
|
|
|
assert "Recommended risk level" in context.risk_context
|
|
assert "moderate" in context.risk_context.lower()
|
|
|
|
def test_get_risk_manager_context(self):
|
|
"""Test getting risk manager context."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Add some trades with outcomes
|
|
for i in range(5):
|
|
trade = TradeRecord.create(
|
|
symbol="MSFT",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=300.0 + i,
|
|
quantity=100,
|
|
)
|
|
# Some winners, some losers
|
|
if i % 2 == 0:
|
|
trade.close(exit_price=310.0 + i)
|
|
else:
|
|
trade.close(exit_price=290.0 + i)
|
|
integration.trade_memory.record_trade(trade)
|
|
|
|
context = integration.get_risk_manager_context(
|
|
symbol="MSFT",
|
|
proposed_trade="Buy 100 shares at $305",
|
|
position_size=30500,
|
|
market_regime=MarketRegime.BULL,
|
|
)
|
|
|
|
assert "Trading history for MSFT" in context.trade_history
|
|
assert "Win rate" in context.trade_history
|
|
|
|
def test_record_trade_outcome(self):
|
|
"""Test recording a trade outcome."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
trade = TradeRecord.create(
|
|
symbol="GOOGL",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=140.0,
|
|
quantity=50,
|
|
)
|
|
trade.close(exit_price=150.0)
|
|
|
|
integration.record_trade_outcome(
|
|
trade=trade,
|
|
situation_context="Tech sector showing AI momentum",
|
|
lesson_learned="AI momentum trades tend to work well",
|
|
)
|
|
|
|
# Trade should be in memory
|
|
trades = integration.trade_memory.get_trades_by_symbol("GOOGL")
|
|
assert len(trades) == 1
|
|
assert trades[0].symbol == "GOOGL"
|
|
|
|
# Situation should be recorded
|
|
assert integration.situation_memory.count() == 1
|
|
|
|
def test_record_risk_decision(self):
|
|
"""Test recording a risk decision."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
decision_id = integration.record_risk_decision(
|
|
category=RiskCategory.POSITION_SIZE,
|
|
risk_level=0.6,
|
|
market_regime=MarketRegime.BULL,
|
|
context="Strong momentum, increasing position",
|
|
)
|
|
|
|
assert decision_id is not None
|
|
|
|
# Decision should be recorded
|
|
decision = integration.risk_memory.get_decision(decision_id)
|
|
assert decision is not None
|
|
assert decision.risk_level == 0.6
|
|
|
|
def test_evaluate_risk_decision(self):
|
|
"""Test evaluating a risk decision."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
decision_id = integration.record_risk_decision(
|
|
category=RiskCategory.LEVERAGE,
|
|
risk_level=0.7,
|
|
market_regime=MarketRegime.LOW_VOLATILITY,
|
|
context="Low vol, using leverage",
|
|
)
|
|
|
|
integration.evaluate_risk_decision(
|
|
decision_id=decision_id,
|
|
outcome="Profitable trade with leverage",
|
|
outcome_score=0.6,
|
|
was_appropriate=True,
|
|
)
|
|
|
|
decision = integration.risk_memory.get_decision(decision_id)
|
|
assert decision.was_appropriate is True
|
|
assert decision.outcome_score == 0.6
|
|
|
|
def test_to_dict_and_from_dict(self):
|
|
"""Test serialization roundtrip."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Add some data
|
|
trade = TradeRecord.create(
|
|
symbol="NVDA",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=500.0,
|
|
quantity=20,
|
|
)
|
|
trade.close(exit_price=550.0)
|
|
integration.trade_memory.record_trade(trade)
|
|
|
|
profile = RiskProfile(user_id="test", base_tolerance=RiskTolerance.AGGRESSIVE)
|
|
integration.risk_memory.set_profile(profile)
|
|
|
|
# Serialize and restore
|
|
data = integration.to_dict()
|
|
restored = AgentMemoryIntegration.from_dict(data)
|
|
|
|
# Verify data preserved
|
|
trades = restored.trade_memory.get_trades_by_symbol("NVDA")
|
|
assert len(trades) == 1
|
|
|
|
profile = restored.risk_memory.get_profile("test")
|
|
assert profile is not None
|
|
assert profile.base_tolerance == RiskTolerance.AGGRESSIVE
|
|
|
|
|
|
# =============================================================================
|
|
# Helper Function Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestCreateMemoryEnhancedPrompt:
|
|
"""Tests for create_memory_enhanced_prompt function."""
|
|
|
|
def test_empty_context_returns_base(self):
|
|
"""Test that empty context returns base prompt unchanged."""
|
|
base_prompt = "Analyze the stock"
|
|
context = MemoryContext()
|
|
|
|
result = create_memory_enhanced_prompt(base_prompt, context)
|
|
|
|
assert result == base_prompt
|
|
|
|
def test_adds_memory_section(self):
|
|
"""Test that memory section is added."""
|
|
base_prompt = "Analyze the stock"
|
|
context = MemoryContext(trade_history="Previous trade: +5%")
|
|
|
|
result = create_memory_enhanced_prompt(base_prompt, context)
|
|
|
|
assert "Analyze the stock" in result
|
|
assert "Memory Context" in result
|
|
assert "Previous trade: +5%" in result
|
|
|
|
def test_respects_context_types(self):
|
|
"""Test filtering by context types."""
|
|
base_prompt = "Analyze"
|
|
context = MemoryContext(
|
|
trade_history="Trade history",
|
|
risk_context="Risk context",
|
|
)
|
|
|
|
result = create_memory_enhanced_prompt(
|
|
base_prompt,
|
|
context,
|
|
[ContextType.TRADE_HISTORY],
|
|
)
|
|
|
|
assert "Trade history" in result
|
|
assert "Risk context" not in result
|
|
|
|
|
|
# =============================================================================
|
|
# Integration Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestAgentMemoryIntegrationWorkflow:
|
|
"""Integration tests for complete memory workflow."""
|
|
|
|
def test_full_trading_workflow(self):
|
|
"""Test a complete trading workflow with memory."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# 1. Set up user profile
|
|
profile = RiskProfile(
|
|
user_id="trader1",
|
|
base_tolerance=RiskTolerance.MODERATE,
|
|
)
|
|
integration.risk_memory.set_profile(profile)
|
|
|
|
# 2. Record some historical trades
|
|
historical_trades = [
|
|
("AAPL", TradeDirection.LONG, 150.0, 160.0),
|
|
("AAPL", TradeDirection.LONG, 155.0, 165.0),
|
|
("AAPL", TradeDirection.LONG, 160.0, 155.0), # Loss
|
|
]
|
|
|
|
for symbol, direction, entry, exit in historical_trades:
|
|
trade = TradeRecord.create(
|
|
symbol=symbol,
|
|
direction=direction,
|
|
entry_price=entry,
|
|
quantity=100,
|
|
)
|
|
trade.close(exit_price=exit)
|
|
integration.record_trade_outcome(
|
|
trade=trade,
|
|
situation_context=f"Trade in {symbol}",
|
|
)
|
|
|
|
# 3. Get analyst context for new analysis
|
|
analyst_context = integration.get_analyst_context(
|
|
symbol="AAPL",
|
|
current_situation="Apple showing momentum",
|
|
analyst_type="momentum",
|
|
)
|
|
|
|
assert len(analyst_context.raw_trades) == 3
|
|
assert "AAPL" in analyst_context.trade_history
|
|
|
|
# 4. Get trader context for decision
|
|
trader_context = integration.get_trader_context(
|
|
symbol="AAPL",
|
|
current_situation="Strong momentum signal",
|
|
proposed_action="buy",
|
|
market_regime=MarketRegime.BULL,
|
|
user_id="trader1",
|
|
)
|
|
|
|
assert "Recommended risk level" in trader_context.risk_context
|
|
|
|
# 5. Get risk manager context
|
|
risk_context = integration.get_risk_manager_context(
|
|
symbol="AAPL",
|
|
proposed_trade="Buy 100 shares",
|
|
position_size=16000,
|
|
market_regime=MarketRegime.BULL,
|
|
user_id="trader1",
|
|
)
|
|
|
|
assert "Win rate" in risk_context.trade_history
|
|
assert "moderate" in risk_context.risk_context.lower()
|
|
|
|
def test_memory_influences_recommendations(self):
|
|
"""Test that memory influences risk recommendations."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Set up profile
|
|
profile = RiskProfile(user_id="default", base_tolerance=RiskTolerance.MODERATE)
|
|
integration.risk_memory.set_profile(profile)
|
|
|
|
# Record successful high-risk decisions
|
|
for i in range(5):
|
|
decision_id = integration.record_risk_decision(
|
|
category=RiskCategory.POSITION_SIZE,
|
|
risk_level=0.7,
|
|
market_regime=MarketRegime.BULL,
|
|
context=f"Successful trade {i}",
|
|
)
|
|
integration.evaluate_risk_decision(
|
|
decision_id=decision_id,
|
|
outcome="Profitable",
|
|
outcome_score=0.6,
|
|
was_appropriate=True,
|
|
)
|
|
|
|
# Get recommendation - should be influenced by history
|
|
risk_level, explanation = integration.risk_memory.recommend_risk_level(
|
|
category=RiskCategory.POSITION_SIZE,
|
|
market_regime=MarketRegime.BULL,
|
|
context="Similar successful situation",
|
|
)
|
|
|
|
# With successful high-risk history, recommendation should be elevated
|
|
# Base moderate (0.375) + bull adjustment (0.1) = 0.475
|
|
# But with history of successful 0.7 decisions, should be higher
|
|
assert risk_level > 0.5
|
|
|
|
def test_lessons_extracted_from_trades(self):
|
|
"""Test that lessons are extracted from trade patterns."""
|
|
integration = AgentMemoryIntegration()
|
|
|
|
# Record winning trades with short hold times
|
|
for i in range(3):
|
|
trade = TradeRecord.create(
|
|
symbol="TSLA",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=200.0,
|
|
quantity=50,
|
|
)
|
|
# Quick win - close the trade
|
|
trade.close(exit_price=220.0)
|
|
integration.trade_memory.record_trade(trade)
|
|
|
|
# Record losing trades with long hold times
|
|
for i in range(3):
|
|
trade = TradeRecord.create(
|
|
symbol="TSLA",
|
|
direction=TradeDirection.LONG,
|
|
entry_price=200.0,
|
|
quantity=50,
|
|
)
|
|
# Slow loss
|
|
trade.close(exit_price=180.0)
|
|
integration.trade_memory.record_trade(trade)
|
|
|
|
context = integration.get_analyst_context(
|
|
symbol="TSLA",
|
|
current_situation="EV sector",
|
|
analyst_type="general",
|
|
)
|
|
|
|
# Lessons should exist (may be strategy continuation or specific lesson)
|
|
assert context.lessons_learned != ""
|
|
|
|
|
|
class TestContextTypeFiltering:
|
|
"""Tests for context type filtering."""
|
|
|
|
def test_all_types(self):
|
|
"""Test ALL context type includes everything."""
|
|
context = MemoryContext(
|
|
trade_history="History",
|
|
risk_context="Risk",
|
|
similar_situations="Similar",
|
|
lessons_learned="Lessons",
|
|
)
|
|
|
|
result = context.to_prompt_string([ContextType.ALL])
|
|
|
|
assert "History" in result
|
|
assert "Risk" in result
|
|
assert "Similar" in result
|
|
assert "Lessons" in result
|
|
|
|
def test_single_type(self):
|
|
"""Test filtering to single type."""
|
|
context = MemoryContext(
|
|
trade_history="History",
|
|
risk_context="Risk",
|
|
)
|
|
|
|
result = context.to_prompt_string([ContextType.RISK_PROFILE])
|
|
|
|
assert "History" not in result
|
|
assert "Risk" in result
|
|
|
|
def test_empty_sections_not_included(self):
|
|
"""Test that empty sections aren't shown."""
|
|
context = MemoryContext(
|
|
trade_history="History",
|
|
risk_context="", # Empty
|
|
)
|
|
|
|
result = context.to_prompt_string([ContextType.ALL])
|
|
|
|
assert "Recent Trade History" in result
|
|
assert "Risk Profile Context" not in result
|