From d72c214d4dfe4bcf929ce36b8b1e374cf04d2ac0 Mon Sep 17 00:00:00 2001 From: Andrew Kaszubski Date: Fri, 26 Dec 2025 20:17:26 +1100 Subject: [PATCH] feat(memory): add layered memory system with FinMem pattern - Fixes #18 --- tests/unit/memory/test_layered_memory.py | 1030 ++++++++++++++++++++++ tradingagents/memory/__init__.py | 27 + tradingagents/memory/layered_memory.py | 729 +++++++++++++++ 3 files changed, 1786 insertions(+) create mode 100644 tests/unit/memory/test_layered_memory.py create mode 100644 tradingagents/memory/__init__.py create mode 100644 tradingagents/memory/layered_memory.py diff --git a/tests/unit/memory/test_layered_memory.py b/tests/unit/memory/test_layered_memory.py new file mode 100644 index 00000000..9c24ebc9 --- /dev/null +++ b/tests/unit/memory/test_layered_memory.py @@ -0,0 +1,1030 @@ +"""Tests for Issue #18: Layered Memory System implementing FinMem pattern. + +This module tests the layered memory system with three scoring dimensions: +- Recency: Time-based decay +- Relevancy: Semantic similarity +- Importance: Significance weighting +""" + +import pytest +import math +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +from tradingagents.memory.layered_memory import ( + LayeredMemory, + MemoryEntry, + MemoryConfig, + ScoringWeights, + DecayFunction, + ImportanceLevel, + ScoredMemory, +) + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + +@pytest.fixture +def default_config(): + """Default memory configuration.""" + return MemoryConfig() + + +@pytest.fixture +def custom_weights(): + """Custom scoring weights.""" + return ScoringWeights(recency=0.4, relevancy=0.4, importance=0.2) + + +@pytest.fixture +def memory_with_default_config(): + """LayeredMemory instance with default configuration.""" + return LayeredMemory() + + +@pytest.fixture +def sample_entry(): + 
"""Sample memory entry.""" + return MemoryEntry.create( + content="Market crash of 10% in tech sector", + metadata={"recommendation": "Reduce tech exposure"}, + importance=ImportanceLevel.HIGH.value, + tags=["market", "tech", "crash"], + ) + + +@pytest.fixture +def multiple_entries(): + """Multiple memory entries with different timestamps and importance.""" + now = datetime.now() + return [ + MemoryEntry( + id="entry-1", + content="Tech sector volatility increased significantly", + metadata={"recommendation": "Reduce exposure"}, + timestamp=now - timedelta(days=1), + importance=ImportanceLevel.HIGH.value, + tags=["tech", "volatility"], + ), + MemoryEntry( + id="entry-2", + content="Federal Reserve announced rate hike", + metadata={"recommendation": "Consider defensive positions"}, + timestamp=now - timedelta(days=7), + importance=ImportanceLevel.CRITICAL.value, + tags=["fed", "rates"], + ), + MemoryEntry( + id="entry-3", + content="Minor retail earnings miss", + metadata={"recommendation": "Monitor retail sector"}, + timestamp=now - timedelta(days=30), + importance=ImportanceLevel.LOW.value, + tags=["retail", "earnings"], + ), + MemoryEntry( + id="entry-4", + content="Normal trading day with slight gains", + metadata={"recommendation": "Hold positions"}, + timestamp=now - timedelta(hours=6), + importance=ImportanceLevel.MINIMAL.value, + tags=["normal"], + ), + ] + + +# ============================================================================= +# ScoringWeights Tests +# ============================================================================= + +class TestScoringWeights: + """Tests for the ScoringWeights class.""" + + def test_default_weights(self): + """Default weights should be 0.3, 0.5, 0.2.""" + weights = ScoringWeights() + assert weights.recency == 0.3 + assert weights.relevancy == 0.5 + assert weights.importance == 0.2 + + def test_custom_weights(self): + """Custom weights should be stored correctly.""" + weights = ScoringWeights(recency=0.4, 
relevancy=0.4, importance=0.2) + assert weights.recency == 0.4 + assert weights.relevancy == 0.4 + assert weights.importance == 0.2 + + def test_total_property(self): + """Total should return sum of all weights.""" + weights = ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2) + assert weights.total == 1.0 + + def test_normalized_weights(self): + """Normalized weights should sum to 1.0.""" + weights = ScoringWeights(recency=2.0, relevancy=4.0, importance=4.0) + normalized = weights.normalized() + assert normalized.recency == 0.2 + assert normalized.relevancy == 0.4 + assert normalized.importance == 0.4 + assert abs(normalized.total - 1.0) < 1e-10 + + def test_normalized_zero_weights(self): + """Zero weights should normalize to equal weights.""" + weights = ScoringWeights(recency=0, relevancy=0, importance=0) + normalized = weights.normalized() + assert abs(normalized.recency - 1/3) < 1e-10 + assert abs(normalized.relevancy - 1/3) < 1e-10 + assert abs(normalized.importance - 1/3) < 1e-10 + + def test_negative_weights_raise_error(self): + """Negative weights should raise ValueError.""" + with pytest.raises(ValueError): + ScoringWeights(recency=-0.1, relevancy=0.5, importance=0.2) + + +# ============================================================================= +# MemoryConfig Tests +# ============================================================================= + +class TestMemoryConfig: + """Tests for the MemoryConfig class.""" + + def test_default_config(self): + """Default config should have sensible defaults.""" + config = MemoryConfig() + assert config.weights.recency == 0.3 + assert config.decay_function == DecayFunction.EXPONENTIAL + assert config.decay_half_life_days == 7 + assert config.max_age_days == 365 + assert config.default_top_k == 5 + + def test_custom_config(self): + """Custom config values should be stored.""" + config = MemoryConfig( + weights=ScoringWeights(0.4, 0.4, 0.2), + decay_function=DecayFunction.LINEAR, + 
decay_half_life_days=14, + max_age_days=180, + ) + assert config.weights.recency == 0.4 + assert config.decay_function == DecayFunction.LINEAR + assert config.decay_half_life_days == 14 + assert config.max_age_days == 180 + + +# ============================================================================= +# ImportanceLevel Tests +# ============================================================================= + +class TestImportanceLevel: + """Tests for the ImportanceLevel enum.""" + + def test_critical_value(self): + """CRITICAL should be 1.0.""" + assert ImportanceLevel.CRITICAL.value == 1.0 + + def test_high_value(self): + """HIGH should be 0.8.""" + assert ImportanceLevel.HIGH.value == 0.8 + + def test_medium_value(self): + """MEDIUM should be 0.5.""" + assert ImportanceLevel.MEDIUM.value == 0.5 + + def test_low_value(self): + """LOW should be 0.2.""" + assert ImportanceLevel.LOW.value == 0.2 + + def test_minimal_value(self): + """MINIMAL should be 0.1.""" + assert ImportanceLevel.MINIMAL.value == 0.1 + + def test_ordering(self): + """Importance levels should be ordered correctly.""" + assert ImportanceLevel.CRITICAL.value > ImportanceLevel.HIGH.value + assert ImportanceLevel.HIGH.value > ImportanceLevel.MEDIUM.value + assert ImportanceLevel.MEDIUM.value > ImportanceLevel.LOW.value + assert ImportanceLevel.LOW.value > ImportanceLevel.MINIMAL.value + + +# ============================================================================= +# MemoryEntry Tests +# ============================================================================= + +class TestMemoryEntry: + """Tests for the MemoryEntry class.""" + + def test_create_entry(self): + """Create should generate a valid entry.""" + entry = MemoryEntry.create( + content="Test content", + metadata={"key": "value"}, + importance=0.7, + tags=["test"], + ) + assert entry.content == "Test content" + assert entry.metadata == {"key": "value"} + assert entry.importance == 0.7 + assert entry.tags == ["test"] + assert entry.id 
is not None + assert entry.timestamp is not None + + def test_create_default_values(self): + """Create with defaults should work.""" + entry = MemoryEntry.create(content="Test") + assert entry.metadata == {} + assert entry.importance == 0.5 + assert entry.tags == [] + + def test_importance_validation(self): + """Importance outside [0, 1] should raise error.""" + with pytest.raises(ValueError): + MemoryEntry.create(content="Test", importance=1.5) + + with pytest.raises(ValueError): + MemoryEntry.create(content="Test", importance=-0.1) + + def test_age_days(self): + """Age days should calculate correctly.""" + now = datetime.now() + entry = MemoryEntry.create(content="Test") + entry.timestamp = now - timedelta(days=5) + + age = entry.age_days(now) + assert abs(age - 5.0) < 0.01 + + def test_age_days_partial(self): + """Age days should handle partial days.""" + now = datetime.now() + entry = MemoryEntry.create(content="Test") + entry.timestamp = now - timedelta(hours=12) + + age = entry.age_days(now) + assert abs(age - 0.5) < 0.01 + + def test_to_dict(self): + """To dict should serialize correctly.""" + entry = MemoryEntry.create( + content="Test", + metadata={"key": "value"}, + importance=0.8, + tags=["tag1"], + ) + data = entry.to_dict() + + assert data["content"] == "Test" + assert data["metadata"] == {"key": "value"} + assert data["importance"] == 0.8 + assert data["tags"] == ["tag1"] + assert "id" in data + assert "timestamp" in data + + def test_from_dict(self): + """From dict should deserialize correctly.""" + data = { + "id": "test-id", + "content": "Test content", + "metadata": {"key": "value"}, + "timestamp": "2024-01-15T10:30:00", + "importance": 0.7, + "tags": ["tag1", "tag2"], + } + entry = MemoryEntry.from_dict(data) + + assert entry.id == "test-id" + assert entry.content == "Test content" + assert entry.metadata == {"key": "value"} + assert entry.importance == 0.7 + assert entry.tags == ["tag1", "tag2"] + + +# 
============================================================================= +# DecayFunction Tests +# ============================================================================= + +class TestDecayFunction: + """Tests for different decay functions.""" + + def test_exponential_decay(self): + """Exponential decay should decrease over time.""" + config = MemoryConfig( + decay_function=DecayFunction.EXPONENTIAL, + decay_lambda=0.1, + ) + memory = LayeredMemory(config=config) + + now = datetime.now() + entry_recent = MemoryEntry.create(content="Recent") + entry_recent.timestamp = now - timedelta(days=1) + + entry_old = MemoryEntry.create(content="Old") + entry_old.timestamp = now - timedelta(days=30) + + score_recent = memory._calculate_recency_score(entry_recent, now) + score_old = memory._calculate_recency_score(entry_old, now) + + assert score_recent > score_old + assert score_recent <= 1.0 + assert score_old >= config.decay_floor + + def test_linear_decay(self): + """Linear decay should decrease linearly.""" + config = MemoryConfig( + decay_function=DecayFunction.LINEAR, + max_age_days=100, + decay_floor=0.0, + ) + memory = LayeredMemory(config=config) + + now = datetime.now() + entry = MemoryEntry.create(content="Test") + entry.timestamp = now - timedelta(days=50) + + score = memory._calculate_recency_score(entry, now) + # At 50 days of 100 max, linear decay should be ~0.5 + assert abs(score - 0.5) < 0.01 + + def test_step_decay(self): + """Step decay should drop after half-life.""" + config = MemoryConfig( + decay_function=DecayFunction.STEP, + decay_half_life_days=7, + decay_floor=0.2, + ) + memory = LayeredMemory(config=config) + + now = datetime.now() + + entry_before = MemoryEntry.create(content="Before") + entry_before.timestamp = now - timedelta(days=5) + + entry_after = MemoryEntry.create(content="After") + entry_after.timestamp = now - timedelta(days=10) + + score_before = memory._calculate_recency_score(entry_before, now) + score_after = 
memory._calculate_recency_score(entry_after, now) + + assert score_before == 1.0 + assert score_after == 0.2 + + def test_power_decay(self): + """Power decay should follow 1/(1+t)^alpha.""" + config = MemoryConfig( + decay_function=DecayFunction.POWER, + decay_lambda=0.5, # alpha + decay_floor=0.0, + ) + memory = LayeredMemory(config=config) + + now = datetime.now() + entry = MemoryEntry.create(content="Test") + entry.timestamp = now - timedelta(days=3) + + score = memory._calculate_recency_score(entry, now) + # At 3 days with alpha=0.5: 1/(1+3)^0.5 = 1/2 = 0.5 + expected = 1 / ((1 + 3) ** 0.5) + assert abs(score - expected) < 0.01 + + def test_decay_floor(self): + """Decay should never go below floor.""" + config = MemoryConfig( + decay_function=DecayFunction.EXPONENTIAL, + decay_lambda=1.0, # Fast decay + decay_floor=0.1, + ) + memory = LayeredMemory(config=config) + + now = datetime.now() + entry = MemoryEntry.create(content="Very old") + entry.timestamp = now - timedelta(days=100) + + score = memory._calculate_recency_score(entry, now) + assert score >= 0.1 + + +# ============================================================================= +# Relevancy Scoring Tests +# ============================================================================= + +class TestRelevancyScoring: + """Tests for relevancy scoring.""" + + def test_word_overlap_identical(self): + """Identical texts should have similarity 1.0.""" + memory = LayeredMemory() + score = memory._word_overlap_similarity( + "tech sector crash", + "tech sector crash", + ) + assert score == 1.0 + + def test_word_overlap_partial(self): + """Partial overlap should have intermediate similarity.""" + memory = LayeredMemory() + score = memory._word_overlap_similarity( + "tech sector crash today", + "tech sector rally tomorrow", + ) + # Common: tech, sector (2) + # Union: tech, sector, crash, today, rally, tomorrow (6) + assert abs(score - 2/6) < 0.01 + + def test_word_overlap_none(self): + """No overlap should have 
similarity 0.""" + memory = LayeredMemory() + score = memory._word_overlap_similarity( + "apple banana cherry", + "dog elephant fox", + ) + assert score == 0.0 + + def test_word_overlap_case_insensitive(self): + """Word overlap should be case insensitive.""" + memory = LayeredMemory() + score = memory._word_overlap_similarity( + "TECH SECTOR", + "tech sector", + ) + assert score == 1.0 + + def test_cosine_similarity_identical(self): + """Identical vectors should have cosine similarity 1.0.""" + memory = LayeredMemory() + vec = [1.0, 2.0, 3.0] + score = memory._cosine_similarity(vec, vec) + # Normalized to [0, 1]: (1 + 1) / 2 = 1.0 + assert abs(score - 1.0) < 0.01 + + def test_cosine_similarity_orthogonal(self): + """Orthogonal vectors should have cosine similarity 0.5 (normalized).""" + memory = LayeredMemory() + vec1 = [1.0, 0.0, 0.0] + vec2 = [0.0, 1.0, 0.0] + score = memory._cosine_similarity(vec1, vec2) + # cos(90°) = 0, normalized to [0, 1]: (0 + 1) / 2 = 0.5 + assert abs(score - 0.5) < 0.01 + + def test_cosine_similarity_opposite(self): + """Opposite vectors should have cosine similarity 0 (normalized).""" + memory = LayeredMemory() + vec1 = [1.0, 0.0, 0.0] + vec2 = [-1.0, 0.0, 0.0] + score = memory._cosine_similarity(vec1, vec2) + # cos(180°) = -1, normalized to [0, 1]: (-1 + 1) / 2 = 0 + assert abs(score - 0.0) < 0.01 + + +# ============================================================================= +# Importance Scoring Tests +# ============================================================================= + +class TestImportanceScoring: + """Tests for importance scoring.""" + + def test_auto_importance_critical(self): + """Returns >= 10% should be CRITICAL.""" + config = MemoryConfig(auto_importance=True) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Major market move", + metadata={"returns": 0.15}, + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == 
ImportanceLevel.CRITICAL.value + + def test_auto_importance_high(self): + """Returns >= 5% should be HIGH.""" + config = MemoryConfig(auto_importance=True) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Significant move", + metadata={"returns": 0.07}, + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == ImportanceLevel.HIGH.value + + def test_auto_importance_medium(self): + """Returns >= 1% should be MEDIUM.""" + config = MemoryConfig(auto_importance=True) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Normal move", + metadata={"returns": 0.03}, + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == ImportanceLevel.MEDIUM.value + + def test_auto_importance_low(self): + """Returns < 1% should be LOW.""" + config = MemoryConfig(auto_importance=True) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Minor move", + metadata={"returns": 0.005}, + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == ImportanceLevel.LOW.value + + def test_auto_importance_negative(self): + """Negative returns should use absolute value.""" + config = MemoryConfig(auto_importance=True) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Market crash", + metadata={"returns": -0.12}, + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == ImportanceLevel.CRITICAL.value + + def test_manual_importance(self): + """Manual importance should be used when auto is disabled.""" + config = MemoryConfig(auto_importance=False) + memory = LayeredMemory(config=config) + + entry = MemoryEntry.create( + content="Test", + metadata={"returns": 0.15}, # Would be CRITICAL with auto + importance=0.3, # Manual LOW + ) + memory.add(entry) + + score = memory._calculate_importance_score(entry) + assert score == 0.3 + + +# 
============================================================================= +# Combined Scoring Tests +# ============================================================================= + +class TestCombinedScoring: + """Tests for combined scoring.""" + + def test_combined_score_calculation(self): + """Combined score should use weighted sum.""" + config = MemoryConfig( + weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2) + ) + memory = LayeredMemory(config=config) + + # Score = 0.3 * 0.8 + 0.5 * 0.6 + 0.2 * 1.0 = 0.24 + 0.3 + 0.2 = 0.74 + score = memory._calculate_combined_score( + recency=0.8, + relevancy=0.6, + importance=1.0, + ) + assert abs(score - 0.74) < 0.01 + + def test_combined_score_normalized(self): + """Combined score with non-normalized weights.""" + config = MemoryConfig( + weights=ScoringWeights(recency=1.0, relevancy=2.0, importance=2.0) + ) + memory = LayeredMemory(config=config) + + # Normalized: 0.2, 0.4, 0.4 + # Score = 0.2 * 0.5 + 0.4 * 0.5 + 0.4 * 0.5 = 0.1 + 0.2 + 0.2 = 0.5 + score = memory._calculate_combined_score( + recency=0.5, + relevancy=0.5, + importance=0.5, + ) + assert abs(score - 0.5) < 0.01 + + +# ============================================================================= +# LayeredMemory CRUD Tests +# ============================================================================= + +class TestLayeredMemoryCRUD: + """Tests for LayeredMemory CRUD operations.""" + + def test_add_entry(self, memory_with_default_config, sample_entry): + """Add should store entry and return ID.""" + memory_id = memory_with_default_config.add(sample_entry) + assert memory_id == sample_entry.id + assert memory_with_default_config.count() == 1 + + def test_add_batch(self, memory_with_default_config, multiple_entries): + """Add batch should store all entries.""" + ids = memory_with_default_config.add_batch(multiple_entries) + assert len(ids) == len(multiple_entries) + assert memory_with_default_config.count() == len(multiple_entries) + + 
def test_get_entry(self, memory_with_default_config, sample_entry): + """Get should return stored entry.""" + memory_with_default_config.add(sample_entry) + retrieved = memory_with_default_config.get(sample_entry.id) + assert retrieved is not None + assert retrieved.content == sample_entry.content + + def test_get_nonexistent(self, memory_with_default_config): + """Get nonexistent ID should return None.""" + result = memory_with_default_config.get("nonexistent-id") + assert result is None + + def test_remove_entry(self, memory_with_default_config, sample_entry): + """Remove should delete entry.""" + memory_with_default_config.add(sample_entry) + result = memory_with_default_config.remove(sample_entry.id) + assert result is True + assert memory_with_default_config.count() == 0 + + def test_remove_nonexistent(self, memory_with_default_config): + """Remove nonexistent ID should return False.""" + result = memory_with_default_config.remove("nonexistent-id") + assert result is False + + def test_clear(self, memory_with_default_config, multiple_entries): + """Clear should remove all entries.""" + memory_with_default_config.add_batch(multiple_entries) + count = memory_with_default_config.clear() + assert count == len(multiple_entries) + assert memory_with_default_config.count() == 0 + + def test_update_importance(self, memory_with_default_config, sample_entry): + """Update importance should modify entry.""" + memory_with_default_config.add(sample_entry) + result = memory_with_default_config.update_importance(sample_entry.id, 0.9) + assert result is True + entry = memory_with_default_config.get(sample_entry.id) + assert entry.importance == 0.9 + + def test_update_importance_invalid(self, memory_with_default_config, sample_entry): + """Update importance with invalid value should raise.""" + memory_with_default_config.add(sample_entry) + with pytest.raises(ValueError): + memory_with_default_config.update_importance(sample_entry.id, 1.5) + + +# 
============================================================================= +# Retrieval Tests +# ============================================================================= + +class TestRetrieval: + """Tests for memory retrieval.""" + + def test_retrieve_empty(self, memory_with_default_config): + """Retrieve from empty memory should return empty list.""" + results = memory_with_default_config.retrieve("test query") + assert results == [] + + def test_retrieve_basic(self, memory_with_default_config, multiple_entries): + """Basic retrieval should return scored memories.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve( + query="tech sector volatility", + top_k=2, + ) + assert len(results) == 2 + assert all(isinstance(r, ScoredMemory) for r in results) + # First result should have highest combined score + assert results[0].combined_score >= results[1].combined_score + + def test_retrieve_respects_top_k(self, memory_with_default_config, multiple_entries): + """Retrieve should respect top_k limit.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve("query", top_k=2) + assert len(results) <= 2 + + def test_retrieve_by_tags(self, memory_with_default_config, multiple_entries): + """Retrieve should filter by tags.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve( + query="market", + tags=["tech"], + ) + # Only entries with "tech" tag + for r in results: + assert "tech" in r.entry.tags + + def test_retrieve_min_score(self, memory_with_default_config, multiple_entries): + """Retrieve should filter by min score.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve( + query="tech", + min_score=0.5, + ) + for r in results: + assert r.combined_score >= 0.5 + + def test_retrieve_by_recency(self, memory_with_default_config, multiple_entries): + 
"""Retrieve by recency should sort by timestamp.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve_by_recency(top_k=3) + assert len(results) == 3 + # Should be sorted by timestamp descending + for i in range(len(results) - 1): + assert results[i].timestamp >= results[i + 1].timestamp + + def test_retrieve_by_importance(self, memory_with_default_config, multiple_entries): + """Retrieve by importance should sort by importance.""" + memory_with_default_config.add_batch(multiple_entries) + results = memory_with_default_config.retrieve_by_importance( + top_k=3, + min_importance=0.1, + ) + # Should be sorted by importance descending + for i in range(len(results) - 1): + assert results[i].importance >= results[i + 1].importance + + +# ============================================================================= +# Scoring Entry Tests +# ============================================================================= + +class TestScoreEntry: + """Tests for score_entry method.""" + + def test_score_entry_returns_all_scores(self, memory_with_default_config, sample_entry): + """Score entry should return all scoring dimensions.""" + memory_with_default_config.add(sample_entry) + scored = memory_with_default_config.score_entry( + sample_entry, + query="tech sector crash", + ) + + assert isinstance(scored, ScoredMemory) + assert 0 <= scored.recency_score <= 1 + assert 0 <= scored.relevancy_score <= 1 + assert 0 <= scored.importance_score <= 1 + assert 0 <= scored.combined_score <= 1 + + +# ============================================================================= +# Statistics Tests +# ============================================================================= + +class TestStatistics: + """Tests for get_statistics method.""" + + def test_statistics_empty(self, memory_with_default_config): + """Statistics for empty memory should return zeros.""" + stats = memory_with_default_config.get_statistics() + assert 
stats["count"] == 0 + assert stats["oldest"] is None + assert stats["newest"] is None + + def test_statistics_with_data(self, memory_with_default_config, multiple_entries): + """Statistics should reflect stored data.""" + memory_with_default_config.add_batch(multiple_entries) + stats = memory_with_default_config.get_statistics() + + assert stats["count"] == len(multiple_entries) + assert stats["oldest"] is not None + assert stats["newest"] is not None + assert "importance_distribution" in stats + + +# ============================================================================= +# Serialization Tests +# ============================================================================= + +class TestSerialization: + """Tests for serialization and deserialization.""" + + def test_to_dict(self, memory_with_default_config, multiple_entries): + """To dict should serialize memory.""" + memory_with_default_config.add_batch(multiple_entries) + data = memory_with_default_config.to_dict() + + assert "memories" in data + assert "config" in data + assert len(data["memories"]) == len(multiple_entries) + + def test_from_dict(self, memory_with_default_config, multiple_entries): + """From dict should deserialize memory.""" + memory_with_default_config.add_batch(multiple_entries) + data = memory_with_default_config.to_dict() + + restored = LayeredMemory.from_dict(data) + assert restored.count() == len(multiple_entries) + + def test_roundtrip(self, memory_with_default_config, sample_entry): + """Roundtrip serialization should preserve data.""" + memory_with_default_config.add(sample_entry) + data = memory_with_default_config.to_dict() + restored = LayeredMemory.from_dict(data) + + original_entry = memory_with_default_config.get(sample_entry.id) + restored_entry = restored.get(sample_entry.id) + + assert restored_entry is not None + assert restored_entry.content == original_entry.content + assert restored_entry.importance == original_entry.importance + + +# 
============================================================================= +# Embedding Function Tests +# ============================================================================= + +class TestEmbeddingFunction: + """Tests for custom embedding function integration.""" + + def test_with_embedding_function(self): + """Memory with embedding function should use it.""" + def mock_embedding(text: str) -> list: + # Simple mock: return length-based vector + return [len(text) / 100, len(text.split()) / 10, 0.5] + + memory = LayeredMemory(embedding_function=mock_embedding) + entry = MemoryEntry.create(content="Test content for embedding") + memory.add(entry) + + # Entry should have embedding + assert entry.embedding is not None + assert len(entry.embedding) == 3 + + def test_embedding_function_failure(self): + """Failed embedding should not crash.""" + def failing_embedding(text: str) -> list: + raise RuntimeError("Embedding failed") + + memory = LayeredMemory(embedding_function=failing_embedding) + entry = MemoryEntry.create(content="Test") + + # Should not raise + memory_id = memory.add(entry) + assert memory_id is not None + + +# ============================================================================= +# Edge Cases +# ============================================================================= + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_very_old_memory(self): + """Very old memories should get minimum recency score.""" + config = MemoryConfig(max_age_days=365, decay_floor=0.05) + memory = LayeredMemory(config=config) + + now = datetime.now() + entry = MemoryEntry.create(content="Ancient history") + entry.timestamp = now - timedelta(days=500) + memory.add(entry) + + score = memory._calculate_recency_score(entry, now) + assert score == 0.05 + + def test_future_timestamp(self): + """Future timestamps should have maximum recency.""" + memory = LayeredMemory() + now = datetime.now() + entry = 
MemoryEntry.create(content="Future") + entry.timestamp = now + timedelta(days=1) + memory.add(entry) + + score = memory._calculate_recency_score(entry, now) + # Negative age should give high recency + assert score >= 0.9 + + def test_empty_content(self): + """Empty content should handle gracefully.""" + memory = LayeredMemory() + score = memory._word_overlap_similarity("", "test") + assert score == 0.0 + + def test_unicode_content(self): + """Unicode content should work correctly.""" + memory = LayeredMemory() + entry = MemoryEntry.create( + content="市场崩盘 📉 Tech crash", + metadata={"language": "mixed"}, + ) + memory.add(entry) + + results = memory.retrieve(query="Tech crash") + assert len(results) >= 0 # Should not crash + + +# ============================================================================= +# Integration Tests +# ============================================================================= + +class TestIntegration: + """Integration tests for the full workflow.""" + + def test_full_workflow(self): + """Test complete memory workflow.""" + # Configure with custom weights + config = MemoryConfig( + weights=ScoringWeights(recency=0.3, relevancy=0.5, importance=0.2), + decay_function=DecayFunction.EXPONENTIAL, + decay_half_life_days=7, + ) + memory = LayeredMemory(config=config) + + # Add memories with different characteristics + now = datetime.now() + + entries = [ + MemoryEntry( + id="recent-critical", + content="Tech sector crash of 15%", + metadata={"recommendation": "Reduce exposure immediately"}, + timestamp=now - timedelta(hours=1), + importance=ImportanceLevel.CRITICAL.value, + tags=["tech", "crash"], + ), + MemoryEntry( + id="old-critical", + content="Previous tech bubble burst", + metadata={"recommendation": "Historical: went to cash"}, + timestamp=now - timedelta(days=365), + importance=ImportanceLevel.CRITICAL.value, + tags=["tech", "historical"], + ), + MemoryEntry( + id="recent-low", + content="Minor fluctuation in retail", + 
metadata={"recommendation": "Hold positions"}, + timestamp=now - timedelta(hours=2), + importance=ImportanceLevel.LOW.value, + tags=["retail"], + ), + ] + + memory.add_batch(entries) + + # Query for tech-related memories + results = memory.retrieve( + query="tech sector volatility crash", + top_k=3, + ) + + assert len(results) == 3 + + # Recent + critical + relevant should be first + assert results[0].entry.id == "recent-critical" + + # Old but critical and relevant should beat recent but low importance + # (depends on exact weights and scores) + + # All scores should be valid + for r in results: + assert 0 <= r.recency_score <= 1 + assert 0 <= r.relevancy_score <= 1 + assert 0 <= r.importance_score <= 1 + assert 0 <= r.combined_score <= 1 + + def test_learning_from_trades(self): + """Simulate learning from trade outcomes.""" + memory = LayeredMemory() + + # Add past trade memories with outcomes + trades = [ + { + "content": "Bought AAPL after earnings beat with bullish guidance", + "returns": 0.08, + "recommendation": "Good entry on earnings beat", + }, + { + "content": "Shorted TSLA on production delays", + "returns": -0.05, + "recommendation": "Avoid shorting high-momentum stocks", + }, + { + "content": "Bought SPY on Fed pivot signal", + "returns": 0.12, + "recommendation": "Fed pivots are reliable entry points", + }, + ] + + for trade in trades: + entry = MemoryEntry.create( + content=trade["content"], + metadata={ + "recommendation": trade["recommendation"], + "returns": trade["returns"], + }, + ) + memory.add(entry) + + # Query for similar situation + results = memory.retrieve( + query="considering buying tech stock after earnings", + top_k=2, + ) + + assert len(results) >= 1 + # Should find relevant trade memories + for r in results: + assert r.entry.metadata.get("recommendation") is not None diff --git a/tradingagents/memory/__init__.py b/tradingagents/memory/__init__.py new file mode 100644 index 00000000..7bff2785 --- /dev/null +++ 
"""Memory module implementing the FinMem pattern for TradingAgents.

This module provides a layered memory system with three scoring dimensions:

- Recency: Time-based decay for more recent memories
- Relevancy: Semantic similarity to current context
- Importance: Significance weighting for impactful events

Issue #18: Layered memory - recency, relevancy, importance scoring
"""

from .layered_memory import (
    LayeredMemory,
    MemoryEntry,
    MemoryConfig,
    ScoringWeights,
    DecayFunction,
    ImportanceLevel,
    # Fix: retrieve() returns ScoredMemory objects, so it is part of the
    # public API and must be importable from the package, not only from
    # the submodule.
    ScoredMemory,
)

__all__ = [
    "LayeredMemory",
    "MemoryEntry",
    "MemoryConfig",
    "ScoringWeights",
    "DecayFunction",
    "ImportanceLevel",
    "ScoredMemory",
]
"""Layered Memory System implementing the FinMem pattern.

The FinMem pattern scores memories along three dimensions at retrieval time:

1. Recency Score: time-based decay -- newer memories are weighted higher.
2. Relevancy Score: semantic similarity between a memory and the query.
3. Importance Score: significance weighting -- impactful events rank higher.

The final retrieval score is a weighted sum:

    score = w_recency * recency + w_relevancy * relevancy + w_importance * importance

Issue #18: [MEM-17] Layered memory - recency, relevancy, importance scoring
"""

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Dict, List, Optional, Any, Callable
import math
import uuid


class DecayFunction(Enum):
    """Supported decay curves for the recency score."""

    EXPONENTIAL = "exponential"  # e^(-lambda * t)
    LINEAR = "linear"            # max(0, 1 - t/T)
    STEP = "step"                # 1 if t < T else decay_floor
    POWER = "power"              # 1 / (1 + t)^alpha


class ImportanceLevel(Enum):
    """Predefined importance levels for common trading events."""

    CRITICAL = 1.0  # Major market events, circuit breakers, >10% moves
    HIGH = 0.8      # Significant gains/losses (>5%), earnings surprises
    MEDIUM = 0.5    # Normal trading, moderate moves (1-5%)
    LOW = 0.2       # Minor events, small moves (<1%)
    MINIMAL = 0.1   # Routine, no significant impact
@dataclass
class ScoringWeights:
    """Relative weights for combining the three scoring dimensions.

    The weights set how much each factor contributes to the final score.
    They normally sum to 1.0, but any non-negative values are accepted and
    can be renormalized via :meth:`normalized`.
    """

    recency: float = 0.3
    relevancy: float = 0.5
    importance: float = 0.2

    def __post_init__(self):
        """Reject negative weights."""
        if min(self.recency, self.relevancy, self.importance) < 0:
            raise ValueError("All weights must be non-negative")

    @property
    def total(self) -> float:
        """Sum of all weights."""
        return self.recency + self.relevancy + self.importance

    def normalized(self) -> "ScoringWeights":
        """Return an equivalent set of weights that sums to 1.0."""
        scale = self.total
        if scale == 0:
            # Degenerate all-zero weights: fall back to an even split.
            return ScoringWeights(recency=1 / 3, relevancy=1 / 3, importance=1 / 3)
        return ScoringWeights(
            recency=self.recency / scale,
            relevancy=self.relevancy / scale,
            importance=self.importance / scale,
        )


@dataclass
class MemoryConfig:
    """Configuration for the LayeredMemory system."""

    # Scoring weights for the combined retrieval score.
    weights: ScoringWeights = field(default_factory=ScoringWeights)

    # Recency configuration.
    decay_function: DecayFunction = DecayFunction.EXPONENTIAL
    decay_lambda: float = 0.1        # For exponential: e^(-lambda * days)
    decay_half_life_days: int = 7    # Alternative: half-life in days
    decay_floor: float = 0.1         # Minimum recency score
    max_age_days: int = 365          # Maximum age to consider

    # Relevancy configuration.
    min_relevancy_threshold: float = 0.0  # Minimum similarity to include
    normalize_relevancy: bool = True      # Normalize to [0, 1]

    # Importance configuration.
    auto_importance: bool = True            # Derive importance from returns
    return_threshold_high: float = 0.05     # >5% return = HIGH importance
    return_threshold_critical: float = 0.10  # >10% return = CRITICAL

    # Retrieval configuration.
    default_top_k: int = 5
    score_threshold: float = 0.0  # Minimum combined score to return
@dataclass
class MemoryEntry:
    """A single memory entry carrying all scoring dimensions.

    Attributes:
        id: Unique identifier.
        content: The memory content (situation description).
        metadata: Additional metadata (recommendations, context, etc.).
        timestamp: When the memory was created.
        importance: Importance score in [0, 1].
        embedding: Vector embedding, when pre-computed.
        tags: Optional tags for filtering.
    """

    id: str
    content: str
    metadata: Dict[str, Any]
    timestamp: datetime
    importance: float = 0.5  # Defaults to the MEDIUM level
    embedding: Optional[List[float]] = None
    tags: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Reject importance values outside [0, 1]."""
        if not 0.0 <= self.importance <= 1.0:
            raise ValueError(f"Importance must be between 0 and 1, got {self.importance}")

    @classmethod
    def create(
        cls,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        importance: float = 0.5,
        tags: Optional[List[str]] = None,
        timestamp: Optional[datetime] = None,
    ) -> "MemoryEntry":
        """Build a new entry with a fresh UUID; timestamp defaults to now."""
        return cls(
            id=str(uuid.uuid4()),
            content=content,
            metadata={} if metadata is None else metadata,
            timestamp=timestamp if timestamp is not None else datetime.now(),
            importance=importance,
            tags=[] if tags is None else tags,
        )

    def age_days(self, reference_time: Optional[datetime] = None) -> float:
        """Age in (possibly fractional) days relative to *reference_time* (default: now)."""
        anchor = reference_time if reference_time is not None else datetime.now()
        # 86400 seconds per day.
        return (anchor - self.timestamp).total_seconds() / 86400

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (embedding is intentionally excluded)."""
        return {
            "id": self.id,
            "content": self.content,
            "metadata": self.metadata,
            "timestamp": self.timestamp.isoformat(),
            "importance": self.importance,
            "tags": self.tags,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryEntry":
        """Rebuild an entry from a dict produced by :meth:`to_dict`."""
        return cls(
            id=data["id"],
            content=data["content"],
            metadata=data.get("metadata", {}),
            timestamp=datetime.fromisoformat(data["timestamp"]),
            importance=data.get("importance", 0.5),
            tags=data.get("tags", []),
        )
@dataclass
class ScoredMemory:
    """A memory entry paired with the scores computed for one query."""

    entry: MemoryEntry
    recency_score: float
    relevancy_score: float
    importance_score: float
    combined_score: float

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the entry and all four scores to a plain dict."""
        return {
            "entry": self.entry.to_dict(),
            "recency_score": self.recency_score,
            "relevancy_score": self.relevancy_score,
            "importance_score": self.importance_score,
            "combined_score": self.combined_score,
        }


class LayeredMemory:
    """Layered Memory System implementing the FinMem pattern.

    Stores memories and retrieves them ranked by a weighted combination of
    three dimensions: recency, relevancy, and importance.

    Example:
        >>> config = MemoryConfig(weights=ScoringWeights(0.3, 0.5, 0.2))
        >>> memory = LayeredMemory(config=config)
        >>>
        >>> # Add a memory
        >>> entry = MemoryEntry.create(
        ...     content="Market crash of 10% in tech sector",
        ...     metadata={"recommendation": "Reduce exposure to tech stocks"},
        ...     importance=ImportanceLevel.CRITICAL.value,
        ... )
        >>> memory.add(entry)
        >>>
        >>> # Retrieve relevant memories
        >>> results = memory.retrieve(
        ...     query="Tech sector volatility increasing",
        ...     top_k=5,
        ... )
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_function: Optional[Callable[[str], List[float]]] = None,
    ):
        """Initialize the layered memory system.

        Args:
            config: Memory configuration; defaults are used when omitted.
            embedding_function: Optional callable mapping text to a vector.
                When absent, retrieval falls back to word-overlap similarity.
        """
        self.config = config if config is not None else MemoryConfig()
        self.embedding_function = embedding_function
        # Primary store: entry id -> entry (insertion-ordered dict).
        self._memories: Dict[str, MemoryEntry] = {}
        # Cache of computed embeddings, keyed by entry id.
        self._embeddings: Dict[str, List[float]] = {}
+ + Args: + entry: The memory entry to add + + Returns: + The ID of the added entry + """ + self._memories[entry.id] = entry + + # Compute and cache embedding if we have an embedding function + if self.embedding_function is not None: + try: + embedding = self.embedding_function(entry.content) + self._embeddings[entry.id] = embedding + entry.embedding = embedding + except Exception: + pass # Silently fail if embedding computation fails + + return entry.id + + def add_batch(self, entries: List[MemoryEntry]) -> List[str]: + """Add multiple memory entries. + + Args: + entries: List of memory entries to add + + Returns: + List of IDs of added entries + """ + return [self.add(entry) for entry in entries] + + def get(self, memory_id: str) -> Optional[MemoryEntry]: + """Get a memory entry by ID. + + Args: + memory_id: The ID of the memory to retrieve + + Returns: + The memory entry or None if not found + """ + return self._memories.get(memory_id) + + def remove(self, memory_id: str) -> bool: + """Remove a memory entry. + + Args: + memory_id: The ID of the memory to remove + + Returns: + True if removed, False if not found + """ + if memory_id in self._memories: + del self._memories[memory_id] + self._embeddings.pop(memory_id, None) + return True + return False + + def clear(self) -> int: + """Remove all memories. + + Returns: + Number of memories removed + """ + count = len(self._memories) + self._memories.clear() + self._embeddings.clear() + return count + + def count(self) -> int: + """Return the number of memories.""" + return len(self._memories) + + def _calculate_recency_score( + self, + entry: MemoryEntry, + reference_time: Optional[datetime] = None, + ) -> float: + """Calculate recency score based on age and decay function. 
+ + Args: + entry: The memory entry + reference_time: Reference time for age calculation (default: now) + + Returns: + Recency score in [0, 1] + """ + age_days = entry.age_days(reference_time) + + # Skip if too old + if age_days > self.config.max_age_days: + return self.config.decay_floor + + decay_func = self.config.decay_function + + if decay_func == DecayFunction.EXPONENTIAL: + # Calculate lambda from half-life if not explicitly set + lambda_val = self.config.decay_lambda + if self.config.decay_half_life_days > 0: + lambda_val = math.log(2) / self.config.decay_half_life_days + score = math.exp(-lambda_val * age_days) + + elif decay_func == DecayFunction.LINEAR: + # Linear decay over max_age_days + score = max(0, 1 - (age_days / self.config.max_age_days)) + + elif decay_func == DecayFunction.STEP: + # Step function: full score until half-life, then floor + if age_days < self.config.decay_half_life_days: + score = 1.0 + else: + score = self.config.decay_floor + + elif decay_func == DecayFunction.POWER: + # Power decay: 1 / (1 + days)^alpha (alpha = decay_lambda) + alpha = self.config.decay_lambda + score = 1 / ((1 + age_days) ** alpha) + else: + score = 1.0 + + # Apply floor + return max(self.config.decay_floor, score) + + def _calculate_relevancy_score( + self, + entry: MemoryEntry, + query: str, + query_embedding: Optional[List[float]] = None, + ) -> float: + """Calculate relevancy score based on semantic similarity. 
+ + Args: + entry: The memory entry + query: The query text + query_embedding: Pre-computed query embedding (optional) + + Returns: + Relevancy score in [0, 1] + """ + # If we have embeddings, use cosine similarity + entry_embedding = self._embeddings.get(entry.id) or entry.embedding + + if entry_embedding is not None and query_embedding is not None: + return self._cosine_similarity(query_embedding, entry_embedding) + + # Fallback to simple word overlap (Jaccard similarity) + return self._word_overlap_similarity(query, entry.content) + + def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float: + """Calculate cosine similarity between two vectors. + + Args: + vec1: First vector + vec2: Second vector + + Returns: + Cosine similarity in [-1, 1], normalized to [0, 1] + """ + if len(vec1) != len(vec2): + return 0.0 + + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + norm1 = math.sqrt(sum(a * a for a in vec1)) + norm2 = math.sqrt(sum(b * b for b in vec2)) + + if norm1 == 0 or norm2 == 0: + return 0.0 + + similarity = dot_product / (norm1 * norm2) + + # Normalize from [-1, 1] to [0, 1] + return (similarity + 1) / 2 + + def _word_overlap_similarity(self, text1: str, text2: str) -> float: + """Calculate word overlap (Jaccard) similarity. + + Args: + text1: First text + text2: Second text + + Returns: + Jaccard similarity in [0, 1] + """ + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = len(words1 & words2) + union = len(words1 | words2) + + return intersection / union if union > 0 else 0.0 + + def _calculate_importance_score(self, entry: MemoryEntry) -> float: + """Get the importance score for an entry. 
+ + Args: + entry: The memory entry + + Returns: + Importance score in [0, 1] + """ + # If auto_importance is enabled and metadata has return info, calculate + if self.config.auto_importance: + returns = entry.metadata.get("returns") or entry.metadata.get("return") + if returns is not None: + abs_return = abs(float(returns)) + if abs_return >= self.config.return_threshold_critical: + return ImportanceLevel.CRITICAL.value + elif abs_return >= self.config.return_threshold_high: + return ImportanceLevel.HIGH.value + elif abs_return >= 0.01: # 1% + return ImportanceLevel.MEDIUM.value + else: + return ImportanceLevel.LOW.value + + # Use the entry's stored importance + return entry.importance + + def _calculate_combined_score( + self, + recency: float, + relevancy: float, + importance: float, + ) -> float: + """Calculate the combined score using configured weights. + + Args: + recency: Recency score [0, 1] + relevancy: Relevancy score [0, 1] + importance: Importance score [0, 1] + + Returns: + Combined score [0, 1] + """ + weights = self.config.weights.normalized() + return ( + weights.recency * recency + + weights.relevancy * relevancy + + weights.importance * importance + ) + + def score_entry( + self, + entry: MemoryEntry, + query: str, + query_embedding: Optional[List[float]] = None, + reference_time: Optional[datetime] = None, + ) -> ScoredMemory: + """Score a memory entry against a query. 
+ + Args: + entry: The memory entry to score + query: The query text + query_embedding: Pre-computed query embedding (optional) + reference_time: Reference time for recency (default: now) + + Returns: + ScoredMemory with all scores computed + """ + recency = self._calculate_recency_score(entry, reference_time) + relevancy = self._calculate_relevancy_score(entry, query, query_embedding) + importance = self._calculate_importance_score(entry) + combined = self._calculate_combined_score(recency, relevancy, importance) + + return ScoredMemory( + entry=entry, + recency_score=recency, + relevancy_score=relevancy, + importance_score=importance, + combined_score=combined, + ) + + def retrieve( + self, + query: str, + top_k: Optional[int] = None, + min_score: Optional[float] = None, + tags: Optional[List[str]] = None, + reference_time: Optional[datetime] = None, + ) -> List[ScoredMemory]: + """Retrieve relevant memories based on query. + + Args: + query: The query text + top_k: Maximum number of results (default: config.default_top_k) + min_score: Minimum combined score (default: config.score_threshold) + tags: Filter by tags (memories must have at least one matching tag) + reference_time: Reference time for recency (default: now) + + Returns: + List of ScoredMemory, sorted by combined_score descending + """ + if not self._memories: + return [] + + top_k = top_k or self.config.default_top_k + min_score = min_score if min_score is not None else self.config.score_threshold + + # Compute query embedding if we have an embedding function + query_embedding = None + if self.embedding_function is not None: + try: + query_embedding = self.embedding_function(query) + except Exception: + pass + + # Score all memories + scored_memories: List[ScoredMemory] = [] + for entry in self._memories.values(): + # Filter by tags if specified + if tags: + if not any(tag in entry.tags for tag in tags): + continue + + scored = self.score_entry(entry, query, query_embedding, reference_time) + + # 
Filter by min score + if scored.combined_score >= min_score: + scored_memories.append(scored) + + # Sort by combined score descending + scored_memories.sort(key=lambda x: x.combined_score, reverse=True) + + # Return top_k + return scored_memories[:top_k] + + def retrieve_by_recency( + self, + top_k: Optional[int] = None, + reference_time: Optional[datetime] = None, + ) -> List[MemoryEntry]: + """Retrieve memories sorted by recency only. + + Args: + top_k: Maximum number of results + reference_time: Reference time for recency + + Returns: + List of MemoryEntry, sorted by timestamp descending + """ + top_k = top_k or self.config.default_top_k + ref = reference_time or datetime.now() + + entries = list(self._memories.values()) + entries.sort(key=lambda x: x.timestamp, reverse=True) + + return entries[:top_k] + + def retrieve_by_importance( + self, + top_k: Optional[int] = None, + min_importance: Optional[float] = None, + ) -> List[MemoryEntry]: + """Retrieve memories sorted by importance only. + + Args: + top_k: Maximum number of results + min_importance: Minimum importance score + + Returns: + List of MemoryEntry, sorted by importance descending + """ + top_k = top_k or self.config.default_top_k + min_importance = min_importance or 0.0 + + entries = [ + e for e in self._memories.values() + if self._calculate_importance_score(e) >= min_importance + ] + entries.sort( + key=lambda x: self._calculate_importance_score(x), + reverse=True, + ) + + return entries[:top_k] + + def update_importance(self, memory_id: str, importance: float) -> bool: + """Update the importance score of a memory. 
+ + Args: + memory_id: The ID of the memory to update + importance: New importance score [0, 1] + + Returns: + True if updated, False if not found + """ + if memory_id not in self._memories: + return False + + if not 0.0 <= importance <= 1.0: + raise ValueError(f"Importance must be between 0 and 1, got {importance}") + + self._memories[memory_id].importance = importance + return True + + def get_statistics(self) -> Dict[str, Any]: + """Get statistics about the memory store. + + Returns: + Dictionary with memory statistics + """ + if not self._memories: + return { + "count": 0, + "oldest": None, + "newest": None, + "avg_importance": 0.0, + "importance_distribution": {}, + } + + entries = list(self._memories.values()) + timestamps = [e.timestamp for e in entries] + importances = [e.importance for e in entries] + + # Count importance levels + importance_dist = { + "critical": sum(1 for i in importances if i >= 0.9), + "high": sum(1 for i in importances if 0.7 <= i < 0.9), + "medium": sum(1 for i in importances if 0.4 <= i < 0.7), + "low": sum(1 for i in importances if 0.1 <= i < 0.4), + "minimal": sum(1 for i in importances if i < 0.1), + } + + return { + "count": len(entries), + "oldest": min(timestamps).isoformat(), + "newest": max(timestamps).isoformat(), + "avg_importance": sum(importances) / len(importances), + "importance_distribution": importance_dist, + } + + def to_dict(self) -> Dict[str, Any]: + """Serialize the memory store to a dictionary. 
+ + Returns: + Dictionary representation of the memory store + """ + return { + "memories": [e.to_dict() for e in self._memories.values()], + "config": { + "weights": { + "recency": self.config.weights.recency, + "relevancy": self.config.weights.relevancy, + "importance": self.config.weights.importance, + }, + "decay_function": self.config.decay_function.value, + "decay_lambda": self.config.decay_lambda, + "decay_half_life_days": self.config.decay_half_life_days, + "decay_floor": self.config.decay_floor, + "max_age_days": self.config.max_age_days, + }, + } + + @classmethod + def from_dict( + cls, + data: Dict[str, Any], + embedding_function: Optional[Callable[[str], List[float]]] = None, + ) -> "LayeredMemory": + """Create a LayeredMemory from a dictionary. + + Args: + data: Dictionary representation + embedding_function: Optional embedding function + + Returns: + LayeredMemory instance + """ + config_data = data.get("config", {}) + weights_data = config_data.get("weights", {}) + + config = MemoryConfig( + weights=ScoringWeights( + recency=weights_data.get("recency", 0.3), + relevancy=weights_data.get("relevancy", 0.5), + importance=weights_data.get("importance", 0.2), + ), + decay_function=DecayFunction(config_data.get("decay_function", "exponential")), + decay_lambda=config_data.get("decay_lambda", 0.1), + decay_half_life_days=config_data.get("decay_half_life_days", 7), + decay_floor=config_data.get("decay_floor", 0.1), + max_age_days=config_data.get("max_age_days", 365), + ) + + memory = cls(config=config, embedding_function=embedding_function) + + for entry_data in data.get("memories", []): + entry = MemoryEntry.from_dict(entry_data) + memory.add(entry) + + return memory