198 lines
7.4 KiB
Python
198 lines
7.4 KiB
Python
"""Unit tests for FinancialSituationMemory."""
|
|
|
|
import pytest
|
|
|
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
|
|
|
|
|
class TestFinancialSituationMemory:
    """Tests for the FinancialSituationMemory class.

    Each test builds its own memory instance, so the tests are independent
    of one another and can run in any order.
    """

    @pytest.mark.unit
    def test_init(self):
        """Test memory initialization."""
        memory = FinancialSituationMemory("test_memory")
        assert memory.name == "test_memory"
        assert len(memory.documents) == 0
        assert len(memory.recommendations) == 0
        assert memory.bm25 is None

    @pytest.mark.unit
    def test_init_with_config(self):
        """Test memory initialization with config (for API compatibility)."""
        memory = FinancialSituationMemory("test_memory", config={"some": "config"})
        assert memory.name == "test_memory"
        # Config is accepted but not used for BM25

    @pytest.mark.unit
    def test_add_situations_single(self):
        """Test adding a single situation."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("High volatility", "Reduce exposure")])

        assert len(memory.documents) == 1
        assert len(memory.recommendations) == 1
        assert memory.documents[0] == "High volatility"
        assert memory.recommendations[0] == "Reduce exposure"
        assert memory.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_multiple(self):
        """Test adding multiple situations."""
        memory = FinancialSituationMemory("test_memory")
        situations = [
            ("High volatility in tech sector", "Reduce exposure"),
            ("Strong earnings report", "Consider buying"),
            ("Rising interest rates", "Review duration"),
        ]
        memory.add_situations(situations)

        assert len(memory.documents) == 3
        assert len(memory.recommendations) == 3
        assert memory.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_incremental(self):
        """Test adding situations incrementally."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("First situation", "First recommendation")])
        memory.add_situations([("Second situation", "Second recommendation")])

        assert len(memory.documents) == 2
        assert memory.recommendations[0] == "First recommendation"
        assert memory.recommendations[1] == "Second recommendation"

    @pytest.mark.unit
    def test_get_memories_returns_matches(self):
        """Test that get_memories returns matching results."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        results = memory.get_memories("inflation concerns in technology sector", n_matches=1)

        assert len(results) == 1
        assert "similarity_score" in results[0]
        assert "matched_situation" in results[0]
        assert "recommendation" in results[0]
        assert results[0]["matched_situation"] == "High inflation affecting tech stocks"

    @pytest.mark.unit
    def test_get_memories_multiple_matches(self):
        """Test that get_memories returns multiple matches."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Inflation concerns rising globally", "Review commodity exposure"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        results = memory.get_memories("inflation worries", n_matches=2)

        assert len(results) == 2
        # At least one inflation-related situation must be among the results.
        # This deliberately does not require both, nor pin their relative order.
        situations = [r["matched_situation"] for r in results]
        assert (
            "High inflation affecting tech stocks" in situations
            or "Inflation concerns rising globally" in situations
        )

    @pytest.mark.unit
    def test_get_memories_empty_returns_empty(self):
        """Test that get_memories on empty memory returns empty list."""
        memory = FinancialSituationMemory("test_memory")
        results = memory.get_memories("any query", n_matches=1)

        assert results == []

    @pytest.mark.unit
    def test_get_memories_normalized_score(self):
        """Test that similarity scores are computed correctly.

        Note: BM25 scores can be negative for documents with low term frequency.
        The normalization divides by max_score but doesn't shift negative scores.
        """
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High volatility tech sector", "Reduce exposure"),
            ("Low volatility bonds", "Stable income"),
        ])

        results = memory.get_memories("volatility in tech", n_matches=2)

        # Verify we get results with similarity_score field
        assert len(results) == 2
        for result in results:
            assert "similarity_score" in result
            # BM25 scores can theoretically be negative, verify it's a number
            assert isinstance(result["similarity_score"], float)

    @pytest.mark.unit
    def test_clear(self):
        """Test that clear empties the memory."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("test", "test recommendation")])

        assert len(memory.documents) == 1
        assert memory.bm25 is not None

        memory.clear()

        assert len(memory.documents) == 0
        assert len(memory.recommendations) == 0
        assert memory.bm25 is None

    @pytest.mark.unit
    def test_get_memories_after_clear(self):
        """Test that get_memories works after clear and re-add."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("First", "Rec1")])
        memory.clear()
        memory.add_situations([("Second", "Rec2")])

        results = memory.get_memories("Second", n_matches=1)

        assert len(results) == 1
        assert results[0]["matched_situation"] == "Second"

    @pytest.mark.unit
    def test_tokenize_lowercase(self):
        """Test that tokenization lowercases text."""
        memory = FinancialSituationMemory("test")
        tokens = memory._tokenize("HELLO World")

        # all() is vacuously true for an empty iterable, so require that
        # tokens were actually produced before checking their case.
        assert tokens
        assert all(token.islower() for token in tokens)

    @pytest.mark.unit
    def test_tokenize_splits_on_punctuation(self):
        """Test that tokenization splits on punctuation."""
        memory = FinancialSituationMemory("test")
        tokens = memory._tokenize("hello, world! test.")

        assert tokens == ["hello", "world", "test"]

    @pytest.mark.unit
    def test_tokenize_handles_numbers(self):
        """Test that tokenization handles numbers."""
        memory = FinancialSituationMemory("test")
        tokens = memory._tokenize("price 123.45 dollars")

        # The decimal point acts as a separator, so both halves appear as tokens.
        assert "123" in tokens
        assert "45" in tokens

    @pytest.mark.unit
    def test_unicode_handling(self):
        """Test that memory handles Unicode content."""
        memory = FinancialSituationMemory("test")
        memory.add_situations([
            ("欧洲市场波动加剧", "考虑减少欧洲敞口"),
            ("日本央行政策调整", "关注汇率变化"),
        ])

        results = memory.get_memories("欧洲市场", n_matches=1)

        assert len(results) == 1
        assert "欧洲" in results[0]["matched_situation"]
|