# TradingAgents/tests/test_memory/test_financial_memory.py

# 198 lines
# 7.4 KiB
# Python

"""Unit tests for FinancialSituationMemory."""
import pytest
from tradingagents.agents.utils.memory import FinancialSituationMemory
class TestFinancialSituationMemory:
    """Tests for the FinancialSituationMemory class.

    Every test builds its own ``FinancialSituationMemory`` instance, so the
    tests are independent of each other and of execution order.
    """

    @pytest.mark.unit
    def test_init(self):
        """A new memory keeps its name and starts with no data and no index."""
        memory = FinancialSituationMemory("test_memory")

        assert memory.name == "test_memory"
        assert len(memory.documents) == 0
        assert len(memory.recommendations) == 0
        assert memory.bm25 is None

    @pytest.mark.unit
    def test_init_with_config(self):
        """The constructor accepts a ``config`` keyword (API compatibility)."""
        memory = FinancialSituationMemory("test_memory", config={"some": "config"})

        assert memory.name == "test_memory"
        # Config is accepted but not used for BM25

    @pytest.mark.unit
    def test_add_situations_single(self):
        """Adding one (situation, recommendation) pair stores both parts."""
        memory = FinancialSituationMemory("test_memory")

        memory.add_situations([("High volatility", "Reduce exposure")])

        assert len(memory.documents) == 1
        assert len(memory.recommendations) == 1
        assert memory.documents[0] == "High volatility"
        assert memory.recommendations[0] == "Reduce exposure"
        # The index must exist once at least one document has been added.
        assert memory.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_multiple(self):
        """Adding several pairs in one call stores all of them."""
        memory = FinancialSituationMemory("test_memory")
        situations = [
            ("High volatility in tech sector", "Reduce exposure"),
            ("Strong earnings report", "Consider buying"),
            ("Rising interest rates", "Review duration"),
        ]

        memory.add_situations(situations)

        assert len(memory.documents) == 3
        assert len(memory.recommendations) == 3
        assert memory.bm25 is not None

    @pytest.mark.unit
    def test_add_situations_incremental(self):
        """Successive add_situations calls append in insertion order."""
        memory = FinancialSituationMemory("test_memory")

        memory.add_situations([("First situation", "First recommendation")])
        memory.add_situations([("Second situation", "Second recommendation")])

        assert len(memory.documents) == 2
        assert memory.recommendations[0] == "First recommendation"
        assert memory.recommendations[1] == "Second recommendation"

    @pytest.mark.unit
    def test_get_memories_returns_matches(self):
        """get_memories returns the best match with the expected result keys."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        results = memory.get_memories(
            "inflation concerns in technology sector", n_matches=1
        )

        assert len(results) == 1
        assert "similarity_score" in results[0]
        assert "matched_situation" in results[0]
        assert "recommendation" in results[0]
        assert results[0]["matched_situation"] == "High inflation affecting tech stocks"

    @pytest.mark.unit
    def test_get_memories_multiple_matches(self):
        """get_memories honours n_matches and surfaces a relevant situation."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High inflation affecting tech stocks", "Consider defensive positions"),
            ("Inflation concerns rising globally", "Review commodity exposure"),
            ("Strong dollar impacting exports", "Review international exposure"),
        ])

        results = memory.get_memories("inflation worries", n_matches=2)

        assert len(results) == 2
        # The assertion requires at least one of the two inflation-related
        # situations in the results (not both), so it does not pin the exact
        # ranking of the top two.
        situations = [r["matched_situation"] for r in results]
        assert (
            "High inflation affecting tech stocks" in situations
            or "Inflation concerns rising globally" in situations
        )

    @pytest.mark.unit
    def test_get_memories_empty_returns_empty(self):
        """Querying a memory with no situations returns an empty list."""
        memory = FinancialSituationMemory("test_memory")

        results = memory.get_memories("any query", n_matches=1)

        assert results == []

    @pytest.mark.unit
    def test_get_memories_normalized_score(self):
        """Every result carries a float ``similarity_score``.

        Only the presence and type of the score are checked, not its value or
        range.

        Note: BM25 scores can be negative for documents with low term frequency.
        The normalization divides by max_score but doesn't shift negative scores.
        """
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([
            ("High volatility tech sector", "Reduce exposure"),
            ("Low volatility bonds", "Stable income"),
        ])

        results = memory.get_memories("volatility in tech", n_matches=2)

        # Verify we get results with similarity_score field
        assert len(results) == 2
        for result in results:
            assert "similarity_score" in result
            # BM25 scores can theoretically be negative, verify it's a number
            assert isinstance(result["similarity_score"], float)

    @pytest.mark.unit
    def test_clear(self):
        """clear() removes all documents, recommendations and the index."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("test", "test recommendation")])
        # Precondition: there is something to clear.
        assert len(memory.documents) == 1
        assert memory.bm25 is not None

        memory.clear()

        assert len(memory.documents) == 0
        assert len(memory.recommendations) == 0
        assert memory.bm25 is None

    @pytest.mark.unit
    def test_get_memories_after_clear(self):
        """After clear() and a re-add, only the new situation is retrievable."""
        memory = FinancialSituationMemory("test_memory")
        memory.add_situations([("First", "Rec1")])
        memory.clear()
        memory.add_situations([("Second", "Rec2")])

        results = memory.get_memories("Second", n_matches=1)

        assert len(results) == 1
        assert results[0]["matched_situation"] == "Second"

    @pytest.mark.unit
    def test_tokenize_lowercase(self):
        """Tokenization lowercases text."""
        memory = FinancialSituationMemory("test")

        tokens = memory._tokenize("HELLO World")

        # Compare against the exact token list: an ``all(...)`` check over the
        # tokens would pass vacuously if the tokenizer returned an empty list.
        assert tokens == ["hello", "world"]

    @pytest.mark.unit
    def test_tokenize_splits_on_punctuation(self):
        """Tokenization splits on punctuation and drops it from the tokens."""
        memory = FinancialSituationMemory("test")

        tokens = memory._tokenize("hello, world! test.")

        assert tokens == ["hello", "world", "test"]

    @pytest.mark.unit
    def test_tokenize_handles_numbers(self):
        """A decimal number is split at the point into separate digit tokens."""
        memory = FinancialSituationMemory("test")

        tokens = memory._tokenize("price 123.45 dollars")

        assert "123" in tokens
        assert "45" in tokens

    @pytest.mark.unit
    def test_unicode_handling(self):
        """Memory stores and retrieves non-ASCII (Chinese) content."""
        memory = FinancialSituationMemory("test")
        memory.add_situations([
            ("欧洲市场波动加剧", "考虑减少欧洲敞口"),
            ("日本央行政策调整", "关注汇率变化"),
        ])

        results = memory.get_memories("欧洲市场", n_matches=1)

        assert len(results) == 1
        # NOTE(review): if the tokenizer does not segment CJK text, the query
        # shares no token with either document and this passes only because
        # the matching document was added first — confirm against _tokenize.
        assert "欧洲" in results[0]["matched_situation"]