diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 00000000..9ec93706 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,142 @@ +import unittest + +from tradingagents.agents.utils.memory import FinancialSituationMemory + + +class TestFinancialSituationMemory(unittest.TestCase): + """Unit tests for FinancialSituationMemory using BM25 retrieval.""" + + def test_add_and_retrieve_single(self): + """添加一条记录,查询能匹配到.""" + memory = FinancialSituationMemory("test_single") + memory.add_situations([ + ("Federal Reserve raised interest rates by 50 basis points amid high inflation", "Consider reducing duration in bond portfolios. Evaluate floating-rate instruments."), + ]) + + results = memory.get_memories("The Fed is hiking rates to combat inflation", n_matches=1) + + self.assertEqual(len(results), 1) + self.assertIn("interest rates", results[0]["matched_situation"].lower()) + self.assertIsNotNone(results[0]["similarity_score"]) + + def test_empty_memory_returns_empty(self): + """空记忆返回空列表.""" + memory = FinancialSituationMemory("test_empty") + + results = memory.get_memories("Any financial situation", n_matches=3) + + self.assertEqual(results, []) + + def test_multiple_matches_ranked(self): + """多条记录,返回按相关性排序的 top-n.""" + memory = FinancialSituationMemory("test_ranked") + memory.add_situations([ + ("Consumer staples stocks outperforming during economic downturn", "Defensive sectors like utilities and healthcare provide stability."), + ("Technology sector experiencing selloff amid rising Treasury yields", "Reduce exposure to high-growth equities. Consider quality value stocks."), + ("Emerging market currencies depreciating due to strong US dollar", "Hedge currency risk. Reduce EM equity exposure."), + ]) + + results = memory.get_memories("Growth stocks are falling as bond yields rise", n_matches=2) + + self.assertEqual(len(results), 2) + self.assertGreaterEqual(results[0]["similarity_score"], results[1]["similarity_score"]) + + def test_clear_memory(self): + """clear() 后查询返回空.""" + memory = FinancialSituationMemory("test_clear") + memory.add_situations([ + ("Oil prices surged 30% due to OPEC production cuts", "Consider energy sector exposure and inflation hedges."), + ]) + + memory.clear() + results = memory.get_memories("Crude oil is up", n_matches=1) + + self.assertEqual(results, []) + self.assertEqual(len(memory.documents), 0) + + def test_normalized_scores_range(self): + """所有 similarity_score 在 [0, 1] 范围内.""" + memory = FinancialSituationMemory("test_scores") + memory.add_situations([ + ("S&P 500 volatility index spiked to 35 amid geopolitical tensions", "Increase cash allocation. Consider VIX hedging strategies."), + ("Investment grade corporate spreads widening significantly", "Stress test portfolio credit risk. Reduce lower-rated IG holdings."), + ("Real estate investment trusts declining due to higher cap rates", "Evaluate REIT exposure. Focus on industrial and data center REITs."), + ]) + + results = memory.get_memories("Market turbulence from international conflict", n_matches=3) + + self.assertEqual(len(results), 3) + for result in results: + self.assertGreaterEqual(result["similarity_score"], 0.0) + self.assertLessEqual(result["similarity_score"], 1.0) + + def test_n_matches_exceeds_documents(self): + """n_matches > 文档数时不崩溃.""" + memory = FinancialSituationMemory("test_excess") + memory.add_situations([ + ("Bank of Japan unexpectedly adjusted yield curve control policy", "Monitor JGB volatility. Assess yen exposure."), + ]) + + # Request more matches than documents exist + results = memory.get_memories("Japanese bond policy change", n_matches=10) + + self.assertEqual(len(results), 1) # Should return only 1 available match + self.assertIsNotNone(results[0]["similarity_score"]) + + def test_rebuild_index_after_add(self): + """多次 add_situations 后索引正确重建.""" + memory = FinancialSituationMemory("test_rebuild") + + memory.add_situations([ + ("Manufacturing PMI contracted to 48, signaling economic slowdown", "Increase defensive positioning. Monitor leading economic indicators."), + ]) + + memory.add_situations([ + ("Bitcoin dropped below $50,000 amid regulatory uncertainty", "Reduce cryptocurrency exposure. Focus on traditional risk assets."), + ]) + + memory.add_situations([ + ("Investment grade corporate spreads widened to 150 basis points", "Credit quality concerns. Prefer higher-rated issuers."), + ]) + + results = memory.get_memories("Economic contraction signals", n_matches=1) + + self.assertEqual(len(memory.documents), 3) + self.assertEqual(len(results), 1) + self.assertIn("PMI", results[0]["matched_situation"]) + + def test_tokenize_handles_special_chars(self): + """_tokenize 能处理特殊字符.""" + memory = FinancialSituationMemory("test_tokenize") + + tokens = memory._tokenize( + "P/E ratio (TTM) is 25.4x, with EPS @ $4.20 & dividend yield 2.1%" + ) + + self.assertIn("p", tokens) + self.assertIn("e", tokens) + self.assertIn("ttm", tokens) + self.assertIn("25", tokens) + self.assertIn("4", tokens) + self.assertIn("20", tokens) + self.assertIn("eps", tokens) + self.assertIn("dividend", tokens) + self.assertIn("yield", tokens) + + def test_retrieve_with_overlapping_concepts(self): + """包含重叠金融概念的多条记录,验证相关性排序.""" + memory = FinancialSituationMemory("test_concepts") + memory.add_situations([ + ("Federal Reserve maintaining dovish stance with low rates", "Favor growth stocks and long-duration bonds."), + ("Federal Reserve turning hawkish with rate hikes", "Rotate to value stocks and short-duration bonds."), + ("European Central Bank keeping rates unchanged", "Monitor EUR currency exposure."), + ]) + + results = memory.get_memories("US central bank raising interest rates", n_matches=2) + + self.assertEqual(len(results), 2) + self.assertGreater(results[0]["similarity_score"], results[1]["similarity_score"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index d278b3c3..67df19fe 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -78,7 +78,7 @@ class FinancialSituationMemory: # Build results results = [] - max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores + max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0 # Normalize scores for idx in top_indices: # Normalize score to 0-1 range for consistency