"""Tests for FinancialSituationMemory JSON file persistence.

Covers:

- RAM-only mode (no config, or a config without the persistence key) behaves
  exactly as it did before persistence was added
- Persistence: add_situations → restart → memories survive
- Persistence: clear → restart → the file reflects the empty state
- Atomic write: data goes to a .tmp file, which is then renamed to .json
- A corrupt or missing file is handled gracefully
- The BM25 index is rebuilt after loading from disk
- Multiple memory instances sharing the same directory don't collide
- The default config includes the memory_persist_dir key
- TradingAgentsGraph passes its config through to all five memory instances
"""

import ast
import inspect
import json
import os
import pathlib
import shutil
import tempfile
import textwrap
import unittest
from unittest import mock

from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.default_config import DEFAULT_CONFIG


# Fixture rows shared by the tests below: each entry is a
# (situation, recommendation) pair, the shape add_situations() is fed.
SAMPLE_DATA = [
    ("High inflation rate with rising interest rates",
     "Consider defensive sectors like utilities."),
    ("Tech sector volatility with institutional selling pressure",
     "Reduce high-growth tech exposure."),
    ("Strong dollar affecting emerging markets",
     "Hedge currency exposure in international positions."),
]


class TestRamOnlyMode(unittest.TestCase):
    """Without a persistence directory the memory must stay purely in-process."""

    def test_no_config(self):
        memory = FinancialSituationMemory("test")
        self.assertIsNone(memory._persist_path)
        memory.add_situations(SAMPLE_DATA[:1])
        self.assertEqual(len(memory.documents), 1)

    def test_config_without_persist_key(self):
        unrelated = {"some_other_key": True}
        memory = FinancialSituationMemory("test", config=unrelated)
        self.assertIsNone(memory._persist_path)

    def test_config_with_none_persist_dir(self):
        # An explicit None disables persistence just like an absent key.
        memory = FinancialSituationMemory("test", config={"memory_persist_dir": None})
        self.assertIsNone(memory._persist_path)

    def test_config_with_empty_string_persist_dir(self):
        # An empty string disables persistence as well.
        memory = FinancialSituationMemory("test", config={"memory_persist_dir": ""})
        self.assertIsNone(memory._persist_path)

    def test_ram_only_data_lost_on_new_instance(self):
        first = FinancialSituationMemory("test")
        first.add_situations(SAMPLE_DATA)
        self.assertEqual(len(first.documents), 3)

        # A fresh instance with the same name sees nothing: no state was saved.
        second = FinancialSituationMemory("test")
        self.assertEqual(len(second.documents), 0)


class TestPersistence(unittest.TestCase):
    """Core persistence round-trip tests."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # Registered cleanups run even if setUp fails after this point
        # (tearDown would not), so the temporary directory never leaks.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.config = {"memory_persist_dir": self.tmpdir}

    def _json_path(self, name):
        """Return the JSON file a memory called *name* persists to."""
        return pathlib.Path(self.tmpdir) / f"{name}.json"

    def test_file_created_on_add(self):
        mem = FinancialSituationMemory("bull_memory", config=self.config)
        json_path = self._json_path("bull_memory")
        self.assertFalse(
            json_path.exists(),
            "File should not exist before adding data",
        )
        mem.add_situations(SAMPLE_DATA[:1])
        self.assertTrue(
            json_path.exists(),
            "File should be created after add_situations",
        )

    def test_atomic_write_leaves_no_tmp_file(self):
        """Completed saves leave only the .json file, with no .tmp leftovers.

        Saves are specified as a .tmp → .json rename. Once a write finishes,
        the staging file must be gone.
        """
        mem = FinancialSituationMemory("atomic", config=self.config)
        mem.add_situations(SAMPLE_DATA[:1])
        mem.add_situations(SAMPLE_DATA[1:2])

        self.assertTrue(self._json_path("atomic").exists())
        leftovers = sorted(p.name for p in pathlib.Path(self.tmpdir).glob("*.tmp"))
        self.assertEqual(leftovers, [], "No .tmp file should survive a completed write")

    def test_round_trip(self):
        """Data survives a full destroy-and-recreate cycle."""
        mem1 = FinancialSituationMemory("rt", config=self.config)
        mem1.add_situations(SAMPLE_DATA)
        self.assertEqual(len(mem1.documents), 3)

        # Destroy first instance, create a new one — simulates process restart
        del mem1
        mem2 = FinancialSituationMemory("rt", config=self.config)
        self.assertEqual(len(mem2.documents), 3)
        self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0])
        self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1])

    def test_incremental_add_persists(self):
        """Multiple add_situations calls accumulate correctly on disk."""
        mem = FinancialSituationMemory("inc", config=self.config)
        mem.add_situations(SAMPLE_DATA[:1])
        mem.add_situations(SAMPLE_DATA[1:2])

        del mem
        mem2 = FinancialSituationMemory("inc", config=self.config)
        self.assertEqual(len(mem2.documents), 2)

    def test_clear_persists_empty(self):
        mem = FinancialSituationMemory("clr", config=self.config)
        mem.add_situations(SAMPLE_DATA)
        mem.clear()

        del mem
        mem2 = FinancialSituationMemory("clr", config=self.config)
        self.assertEqual(len(mem2.documents), 0)
        self.assertEqual(len(mem2.recommendations), 0)

    def test_bm25_rebuilt_after_load(self):
        """BM25 index should work after loading from disk."""
        mem1 = FinancialSituationMemory("bm25", config=self.config)
        mem1.add_situations(SAMPLE_DATA)

        del mem1
        mem2 = FinancialSituationMemory("bm25", config=self.config)
        results = mem2.get_memories("rising interest rates inflation", n_matches=1)
        self.assertEqual(len(results), 1)
        self.assertIn("inflation", results[0]["matched_situation"].lower())

    def test_json_file_content_valid(self):
        """Persisted JSON has the expected schema."""
        mem = FinancialSituationMemory("schema", config=self.config)
        mem.add_situations(SAMPLE_DATA[:2])

        data = json.loads(self._json_path("schema").read_text(encoding="utf-8"))
        self.assertIn("situations", data)
        self.assertIn("recommendations", data)
        self.assertEqual(len(data["situations"]), 2)
        self.assertEqual(len(data["recommendations"]), 2)

    def test_unicode_round_trip(self):
        """Non-ASCII data (e.g. CJK) survives persistence."""
        mem = FinancialSituationMemory("uni", config=self.config)
        mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")])

        del mem
        mem2 = FinancialSituationMemory("uni", config=self.config)
        self.assertEqual(mem2.documents[0], "通胀上升,利率攀升")
        self.assertEqual(mem2.recommendations[0], "考虑防御性板块")


class TestEdgeCases(unittest.TestCase):
    """Graceful handling of corrupt / missing files."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # Registered cleanups run even if setUp fails after this point.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.config = {"memory_persist_dir": self.tmpdir}

    def test_corrupt_json_starts_fresh(self):
        fp = pathlib.Path(self.tmpdir) / "bad.json"
        fp.write_text("NOT VALID JSON {{{", encoding="utf-8")
        mem = FinancialSituationMemory("bad", config=self.config)
        self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file")
        self.assertEqual(len(mem.recommendations), 0)

    def test_missing_keys_starts_fresh(self):
        fp = pathlib.Path(self.tmpdir) / "partial.json"
        fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8")
        mem = FinancialSituationMemory("partial", config=self.config)
        self.assertEqual(len(mem.documents), 0)

    def test_mismatched_lengths_truncates(self):
        """If situations and recommendations have different lengths, zip truncates."""
        fp = pathlib.Path(self.tmpdir) / "mismatch.json"
        fp.write_text(json.dumps({
            "situations": ["s1", "s2", "s3"],
            "recommendations": ["r1", "r2"],
        }), encoding="utf-8")
        mem = FinancialSituationMemory("mismatch", config=self.config)
        self.assertEqual(len(mem.documents), 2)
        self.assertEqual(len(mem.recommendations), 2)
        # Truncation keeps the leading, still-paired entries.
        self.assertEqual(list(mem.documents), ["s1", "s2"])
        self.assertEqual(list(mem.recommendations), ["r1", "r2"])

    def test_parent_dirs_created(self):
        nested = os.path.join(self.tmpdir, "a", "b", "c")
        config = {"memory_persist_dir": nested}
        mem = FinancialSituationMemory("nested", config=config)
        mem.add_situations(SAMPLE_DATA[:1])
        self.assertTrue(pathlib.Path(nested).is_dir())
        self.assertTrue((pathlib.Path(nested) / "nested.json").exists())

    def test_tilde_expansion(self):
        """~ in persist dir is expanded.

        The home directory is redirected into the per-test temp dir so the
        test never reads, creates, or deletes anything under the real
        user's home. HOME is consulted on POSIX and USERPROFILE on Windows.
        """
        fake_home = os.path.join(self.tmpdir, "home")
        os.makedirs(fake_home)
        config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"}
        sandbox = {"HOME": fake_home, "USERPROFILE": fake_home}
        with mock.patch.dict(os.environ, sandbox):
            mem = FinancialSituationMemory("tilde", config=config)
            expanded_home = os.path.expanduser("~")
        self.assertFalse(str(mem._persist_path).startswith("~"))
        self.assertIn(expanded_home, str(mem._persist_path))


class TestMultipleInstances(unittest.TestCase):
    """Multiple memory instances in the same directory don't collide."""

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()
        # Registered cleanups run even if setUp fails after this point.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.config = {"memory_persist_dir": self.tmpdir}

    def test_independent_files(self):
        bull = FinancialSituationMemory("bull_memory", config=self.config)
        bear = FinancialSituationMemory("bear_memory", config=self.config)

        bull.add_situations(SAMPLE_DATA[:1])
        bear.add_situations(SAMPLE_DATA[1:2])

        # Each memory name maps to its own file in the shared directory.
        root = pathlib.Path(self.tmpdir)
        self.assertTrue((root / "bull_memory.json").exists())
        self.assertTrue((root / "bear_memory.json").exists())

        del bull, bear

        bull2 = FinancialSituationMemory("bull_memory", config=self.config)
        bear2 = FinancialSituationMemory("bear_memory", config=self.config)

        self.assertEqual(len(bull2.documents), 1)
        self.assertEqual(len(bear2.documents), 1)
        # Each instance must reload exactly its own data, not the other's.
        self.assertEqual(list(bull2.documents), [SAMPLE_DATA[0][0]])
        self.assertEqual(list(bear2.documents), [SAMPLE_DATA[1][0]])


class TestDefaultConfig(unittest.TestCase):
    """Guard the shipped default for the memory persistence setting."""

    def test_key_exists(self):
        known_keys = DEFAULT_CONFIG.keys()
        self.assertIn("memory_persist_dir", known_keys)

    def test_default_is_none(self):
        # None means RAM-only; anything else would silently start writing files.
        persist_dir = DEFAULT_CONFIG["memory_persist_dir"]
        self.assertIsNone(
            persist_dir,
            "Default should be None (RAM-only) for backward compatibility",
        )


class TestTradingGraphIntegration(unittest.TestCase):
    """Verify TradingAgentsGraph passes config to all memory instances.

    This is a source-code audit test. It does NOT instantiate the graph,
    which requires LLM API keys and heavy dependencies. Instead it parses
    the constructor's source with ``ast`` to confirm the config dict is
    forwarded, so harmless reformatting of the call sites does not break it.
    """

    MEMORY_NAMES = (
        "bull_memory",
        "bear_memory",
        "trader_memory",
        "invest_judge_memory",
        "portfolio_manager_memory",
    )

    @staticmethod
    def _memory_names_given_config(source):
        """Return names of FinancialSituationMemory calls that receive self.config.

        *source* is Python source text. It is dedented first, because
        ``inspect.getsource`` on a method returns indented code. A call counts
        when its first positional argument is a string literal and it passes
        ``self.config`` either as the second positional argument or via the
        ``config=`` keyword.
        """

        def is_self_config(node):
            return (
                isinstance(node, ast.Attribute)
                and node.attr == "config"
                and isinstance(node.value, ast.Name)
                and node.value.id == "self"
            )

        names = set()
        for node in ast.walk(ast.parse(textwrap.dedent(source))):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            # Accept both a bare name and an attribute access (module.Class).
            callee = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
            if callee != "FinancialSituationMemory" or not node.args:
                continue
            first = node.args[0]
            if not (isinstance(first, ast.Constant) and isinstance(first.value, str)):
                continue
            candidates = list(node.args[1:2])
            candidates += [kw.value for kw in node.keywords if kw.arg == "config"]
            if any(is_self_config(arg) for arg in candidates):
                names.add(first.value)
        return names

    def test_config_forwarded_to_memory_constructors(self):
        """All five FinancialSituationMemory() calls receive self.config."""
        # Imported lazily so only this test depends on the graph module.
        from tradingagents.graph.trading_graph import TradingAgentsGraph

        source = inspect.getsource(TradingAgentsGraph.__init__)
        forwarded = self._memory_names_given_config(source)

        for name in self.MEMORY_NAMES:
            with self.subTest(memory=name):
                self.assertIn(
                    name,
                    forwarded,
                    f"TradingAgentsGraph.__init__ should pass self.config to {name}",
                )


if __name__ == "__main__":
    # Running this file directly hands control to unittest's CLI runner,
    # which collects the TestCase classes defined in this module.
    unittest.main(module="__main__")