# TradingAgents/tests/test_memory_persistence.py (293 lines, 11 KiB, Python)

"""Tests for FinancialSituationMemory JSON file persistence.
Covers:
- RAM-only mode (no config / config without key) behaves identically to before
- Persistence: add_situations → restart → memories survive
- Persistence: clear → restart → file reflects empty state
- Atomic write: .tmp → .json rename
- Corrupt / missing file is handled gracefully
- BM25 index is rebuilt after loading from disk
- Multiple memory instances sharing the same dir don't collide
- Default config includes memory_persist_dir key
- TradingAgentsGraph passes config through to all five memory instances
"""
import json
import os
import pathlib
import shutil
import tempfile
import unittest
from unittest import mock

from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.default_config import DEFAULT_CONFIG
# (situation, recommendation) pairs shared by the tests below. The two tuples
# are zipped positionally, so keep them the same length and in matching order.
_SAMPLE_SITUATIONS = (
    "High inflation rate with rising interest rates",
    "Tech sector volatility with institutional selling pressure",
    "Strong dollar affecting emerging markets",
)
_SAMPLE_RECOMMENDATIONS = (
    "Consider defensive sectors like utilities.",
    "Reduce high-growth tech exposure.",
    "Hedge currency exposure in international positions.",
)
SAMPLE_DATA = list(zip(_SAMPLE_SITUATIONS, _SAMPLE_RECOMMENDATIONS))
class TestRamOnlyMode(unittest.TestCase):
"""Ensure RAM-only mode (no persistence) is fully backward-compatible."""
def test_no_config(self):
mem = FinancialSituationMemory("test")
self.assertIsNone(mem._persist_path)
mem.add_situations(SAMPLE_DATA[:1])
self.assertEqual(len(mem.documents), 1)
def test_config_without_persist_key(self):
mem = FinancialSituationMemory("test", config={"some_other_key": True})
self.assertIsNone(mem._persist_path)
def test_config_with_none_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": None})
self.assertIsNone(mem._persist_path)
def test_config_with_empty_string_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": ""})
self.assertIsNone(mem._persist_path)
def test_ram_only_data_lost_on_new_instance(self):
mem1 = FinancialSituationMemory("test")
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
mem2 = FinancialSituationMemory("test")
self.assertEqual(len(mem2.documents), 0)
class TestPersistence(unittest.TestCase):
"""Core persistence round-trip tests."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_file_created_on_add(self):
mem = FinancialSituationMemory("bull_memory", config=self.config)
self.assertFalse(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should not exist before adding data",
)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should be created after add_situations",
)
def test_round_trip(self):
"""Data survives a full destroy-and-recreate cycle."""
mem1 = FinancialSituationMemory("rt", config=self.config)
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
# Destroy first instance, create a new one — simulates process restart
del mem1
mem2 = FinancialSituationMemory("rt", config=self.config)
self.assertEqual(len(mem2.documents), 3)
self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0])
self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1])
def test_incremental_add_persists(self):
"""Multiple add_situations calls accumulate correctly on disk."""
mem = FinancialSituationMemory("inc", config=self.config)
mem.add_situations(SAMPLE_DATA[:1])
mem.add_situations(SAMPLE_DATA[1:2])
del mem
mem2 = FinancialSituationMemory("inc", config=self.config)
self.assertEqual(len(mem2.documents), 2)
def test_clear_persists_empty(self):
mem = FinancialSituationMemory("clr", config=self.config)
mem.add_situations(SAMPLE_DATA)
mem.clear()
del mem
mem2 = FinancialSituationMemory("clr", config=self.config)
self.assertEqual(len(mem2.documents), 0)
self.assertEqual(len(mem2.recommendations), 0)
def test_bm25_rebuilt_after_load(self):
"""BM25 index should work after loading from disk."""
mem1 = FinancialSituationMemory("bm25", config=self.config)
mem1.add_situations(SAMPLE_DATA)
del mem1
mem2 = FinancialSituationMemory("bm25", config=self.config)
results = mem2.get_memories("rising interest rates inflation", n_matches=1)
self.assertEqual(len(results), 1)
self.assertIn("inflation", results[0]["matched_situation"].lower())
def test_json_file_content_valid(self):
"""Persisted JSON has the expected schema."""
mem = FinancialSituationMemory("schema", config=self.config)
mem.add_situations(SAMPLE_DATA[:2])
fp = pathlib.Path(self.tmpdir) / "schema.json"
data = json.loads(fp.read_text(encoding="utf-8"))
self.assertIn("situations", data)
self.assertIn("recommendations", data)
self.assertEqual(len(data["situations"]), 2)
self.assertEqual(len(data["recommendations"]), 2)
def test_unicode_round_trip(self):
"""Non-ASCII data (e.g. CJK) survives persistence."""
mem = FinancialSituationMemory("uni", config=self.config)
mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")])
del mem
mem2 = FinancialSituationMemory("uni", config=self.config)
self.assertEqual(mem2.documents[0], "通胀上升,利率攀升")
self.assertEqual(mem2.recommendations[0], "考虑防御性板块")
class TestEdgeCases(unittest.TestCase):
"""Graceful handling of corrupt / missing files."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_corrupt_json_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "bad.json"
fp.write_text("NOT VALID JSON {{{", encoding="utf-8")
mem = FinancialSituationMemory("bad", config=self.config)
self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file")
def test_missing_keys_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "partial.json"
fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8")
mem = FinancialSituationMemory("partial", config=self.config)
self.assertEqual(len(mem.documents), 0)
def test_mismatched_lengths_truncates(self):
"""If situations and recommendations have different lengths, zip truncates."""
fp = pathlib.Path(self.tmpdir) / "mismatch.json"
fp.write_text(json.dumps({
"situations": ["s1", "s2", "s3"],
"recommendations": ["r1", "r2"],
}), encoding="utf-8")
mem = FinancialSituationMemory("mismatch", config=self.config)
self.assertEqual(len(mem.documents), 2)
self.assertEqual(len(mem.recommendations), 2)
def test_parent_dirs_created(self):
nested = os.path.join(self.tmpdir, "a", "b", "c")
config = {"memory_persist_dir": nested}
mem = FinancialSituationMemory("nested", config=config)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(pathlib.Path(nested).is_dir())
self.assertTrue((pathlib.Path(nested) / "nested.json").exists())
def test_tilde_expansion(self):
"""~ in persist dir is expanded."""
config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"}
try:
mem = FinancialSituationMemory("tilde", config=config)
self.assertFalse(str(mem._persist_path).startswith("~"))
self.assertIn(os.path.expanduser("~"), str(mem._persist_path))
finally:
import shutil
shutil.rmtree(
os.path.expanduser("~/.__test_tradingagents_mem__"),
ignore_errors=True,
)
class TestMultipleInstances(unittest.TestCase):
"""Multiple memory instances in the same directory don't collide."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_independent_files(self):
bull = FinancialSituationMemory("bull_memory", config=self.config)
bear = FinancialSituationMemory("bear_memory", config=self.config)
bull.add_situations(SAMPLE_DATA[:1])
bear.add_situations(SAMPLE_DATA[1:2])
del bull, bear
bull2 = FinancialSituationMemory("bull_memory", config=self.config)
bear2 = FinancialSituationMemory("bear_memory", config=self.config)
self.assertEqual(len(bull2.documents), 1)
self.assertEqual(len(bear2.documents), 1)
self.assertNotEqual(bull2.documents[0], bear2.documents[0])
class TestDefaultConfig(unittest.TestCase):
"""Verify the default config includes the new key."""
def test_key_exists(self):
self.assertIn("memory_persist_dir", DEFAULT_CONFIG)
def test_default_is_none(self):
self.assertIsNone(
DEFAULT_CONFIG["memory_persist_dir"],
"Default should be None (RAM-only) for backward compatibility",
)
class TestTradingGraphIntegration(unittest.TestCase):
"""Verify TradingAgentsGraph passes config to all memory instances.
This is a source-code audit test — it does NOT instantiate the graph
(which requires LLM API keys and heavy dependencies), but instead
inspects the constructor to confirm the config dict is forwarded.
"""
def test_config_forwarded_to_memory_constructors(self):
"""All five FinancialSituationMemory() calls receive self.config."""
import inspect
from tradingagents.graph.trading_graph import TradingAgentsGraph
source = inspect.getsource(TradingAgentsGraph.__init__)
memory_names = [
"bull_memory",
"bear_memory",
"trader_memory",
"invest_judge_memory",
"portfolio_manager_memory",
]
for name in memory_names:
# Expect: FinancialSituationMemory("name", self.config)
self.assertIn(
f'FinancialSituationMemory("{name}", self.config)',
source,
f"TradingAgentsGraph.__init__ should pass self.config to {name}",
)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly as a script.
    unittest.main()