From 1e2164ec785b666d44abaaecff9c835d2c79347b Mon Sep 17 00:00:00 2001 From: voidborne-d Date: Fri, 17 Apr 2026 15:54:06 +0000 Subject: [PATCH] fix: add JSON file persistence to FinancialSituationMemory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FinancialSituationMemory stored all learned lessons in RAM-only Python lists with no persistence layer. Every process restart or new TradingAgentsGraph() instance wiped all memory, making reflect_and_remember() useless in practice — especially in server/API deployments where a new graph is created per request. Changes: 1. memory.py — add optional JSON file persistence: - New config key `memory_persist_dir`: when set to a directory path, each memory instance writes `{memory_persist_dir}/{name}.json` on every mutation (add_situations / clear) and loads it on construction. - When unset or None (the default), behaviour is identical to before (RAM-only) — fully backward compatible. - Atomic-ish writes via .tmp → rename to avoid corruption on crash. - Graceful handling of corrupt / missing / partial JSON files. - Tilde expansion (`~/...`) and automatic parent directory creation. 2. default_config.py — add `memory_persist_dir: None` to DEFAULT_CONFIG. 3. main.py — enable persistence in the example and improve reflect_and_remember documentation comment. 4. 
tests/test_memory_persistence.py — 21 regression tests covering: - RAM-only backward compatibility (5 tests) - Persistence round-trip, incremental add, clear, BM25 rebuild, JSON schema, Unicode (7 tests) - Edge cases: corrupt JSON, missing keys, mismatched lengths, nested directory creation, tilde expansion (5 tests) - Multiple instances sharing same directory (1 test) - Default config key existence (2 tests) - Source audit: TradingAgentsGraph passes config to all 5 memories (1 test) Closes #563 --- main.py | 13 +- tests/test_memory_persistence.py | 292 +++++++++++++++++++++++++++ tradingagents/agents/utils/memory.py | 74 ++++++- tradingagents/default_config.py | 2 + 4 files changed, 377 insertions(+), 4 deletions(-) create mode 100644 tests/test_memory_persistence.py diff --git a/main.py b/main.py index c94fde32..bca41dc3 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,10 @@ config["data_vendors"] = { "news_data": "yfinance", # Options: alpha_vantage, yfinance } +# Enable memory persistence so lessons survive restarts (optional). +# Set to None or omit to keep the default RAM-only behaviour. +config["memory_persist_dir"] = "~/.tradingagents/memory" + # Initialize with custom config ta = TradingAgentsGraph(debug=True, config=config) @@ -27,5 +31,10 @@ ta = TradingAgentsGraph(debug=True, config=config) _, decision = ta.propagate("NVDA", "2024-05-10") print(decision) -# Memorize mistakes and reflect -# ta.reflect_and_remember(1000) # parameter is the position returns +# Reflect on the decision after observing actual returns. +# Call this once the position closes and the P&L is known. +# The signed float indicates outcome: positive = correct signal, +# negative = incorrect signal. Lessons are persisted when +# memory_persist_dir is set, so the next TradingAgentsGraph +# instance will load them automatically. +# ta.reflect_and_remember(0.03) # e.g. 
3 % gain diff --git a/tests/test_memory_persistence.py b/tests/test_memory_persistence.py new file mode 100644 index 00000000..5d1afea8 --- /dev/null +++ b/tests/test_memory_persistence.py @@ -0,0 +1,292 @@ +"""Tests for FinancialSituationMemory JSON file persistence. + +Covers: + - RAM-only mode (no config / config without key) behaves identically to before + - Persistence: add_situations → restart → memories survive + - Persistence: clear → restart → file reflects empty state + - Atomic write: .tmp → .json rename + - Corrupt / missing file is handled gracefully + - BM25 index is rebuilt after loading from disk + - Multiple memory instances sharing the same dir don't collide + - Default config includes memory_persist_dir key + - TradingAgentsGraph passes config through to all five memory instances +""" + +import json +import os +import pathlib +import tempfile +import unittest + +from tradingagents.agents.utils.memory import FinancialSituationMemory +from tradingagents.default_config import DEFAULT_CONFIG + + +SAMPLE_DATA = [ + ( + "High inflation rate with rising interest rates", + "Consider defensive sectors like utilities.", + ), + ( + "Tech sector volatility with institutional selling pressure", + "Reduce high-growth tech exposure.", + ), + ( + "Strong dollar affecting emerging markets", + "Hedge currency exposure in international positions.", + ), +] + + +class TestRamOnlyMode(unittest.TestCase): + """Ensure RAM-only mode (no persistence) is fully backward-compatible.""" + + def test_no_config(self): + mem = FinancialSituationMemory("test") + self.assertIsNone(mem._persist_path) + mem.add_situations(SAMPLE_DATA[:1]) + self.assertEqual(len(mem.documents), 1) + + def test_config_without_persist_key(self): + mem = FinancialSituationMemory("test", config={"some_other_key": True}) + self.assertIsNone(mem._persist_path) + + def test_config_with_none_persist_dir(self): + mem = FinancialSituationMemory("test", config={"memory_persist_dir": None}) + 
self.assertIsNone(mem._persist_path) + + def test_config_with_empty_string_persist_dir(self): + mem = FinancialSituationMemory("test", config={"memory_persist_dir": ""}) + self.assertIsNone(mem._persist_path) + + def test_ram_only_data_lost_on_new_instance(self): + mem1 = FinancialSituationMemory("test") + mem1.add_situations(SAMPLE_DATA) + self.assertEqual(len(mem1.documents), 3) + + mem2 = FinancialSituationMemory("test") + self.assertEqual(len(mem2.documents), 0) + + +class TestPersistence(unittest.TestCase): + """Core persistence round-trip tests.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.config = {"memory_persist_dir": self.tmpdir} + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_file_created_on_add(self): + mem = FinancialSituationMemory("bull_memory", config=self.config) + self.assertFalse( + (pathlib.Path(self.tmpdir) / "bull_memory.json").exists(), + "File should not exist before adding data", + ) + mem.add_situations(SAMPLE_DATA[:1]) + self.assertTrue( + (pathlib.Path(self.tmpdir) / "bull_memory.json").exists(), + "File should be created after add_situations", + ) + + def test_round_trip(self): + """Data survives a full destroy-and-recreate cycle.""" + mem1 = FinancialSituationMemory("rt", config=self.config) + mem1.add_situations(SAMPLE_DATA) + self.assertEqual(len(mem1.documents), 3) + + # Destroy first instance, create a new one — simulates process restart + del mem1 + mem2 = FinancialSituationMemory("rt", config=self.config) + self.assertEqual(len(mem2.documents), 3) + self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0]) + self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1]) + + def test_incremental_add_persists(self): + """Multiple add_situations calls accumulate correctly on disk.""" + mem = FinancialSituationMemory("inc", config=self.config) + mem.add_situations(SAMPLE_DATA[:1]) + mem.add_situations(SAMPLE_DATA[1:2]) + + del mem + mem2 = 
FinancialSituationMemory("inc", config=self.config) + self.assertEqual(len(mem2.documents), 2) + + def test_clear_persists_empty(self): + mem = FinancialSituationMemory("clr", config=self.config) + mem.add_situations(SAMPLE_DATA) + mem.clear() + + del mem + mem2 = FinancialSituationMemory("clr", config=self.config) + self.assertEqual(len(mem2.documents), 0) + self.assertEqual(len(mem2.recommendations), 0) + + def test_bm25_rebuilt_after_load(self): + """BM25 index should work after loading from disk.""" + mem1 = FinancialSituationMemory("bm25", config=self.config) + mem1.add_situations(SAMPLE_DATA) + + del mem1 + mem2 = FinancialSituationMemory("bm25", config=self.config) + results = mem2.get_memories("rising interest rates inflation", n_matches=1) + self.assertEqual(len(results), 1) + self.assertIn("inflation", results[0]["matched_situation"].lower()) + + def test_json_file_content_valid(self): + """Persisted JSON has the expected schema.""" + mem = FinancialSituationMemory("schema", config=self.config) + mem.add_situations(SAMPLE_DATA[:2]) + + fp = pathlib.Path(self.tmpdir) / "schema.json" + data = json.loads(fp.read_text(encoding="utf-8")) + self.assertIn("situations", data) + self.assertIn("recommendations", data) + self.assertEqual(len(data["situations"]), 2) + self.assertEqual(len(data["recommendations"]), 2) + + def test_unicode_round_trip(self): + """Non-ASCII data (e.g. 
CJK) survives persistence.""" + mem = FinancialSituationMemory("uni", config=self.config) + mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")]) + + del mem + mem2 = FinancialSituationMemory("uni", config=self.config) + self.assertEqual(mem2.documents[0], "通胀上升,利率攀升") + self.assertEqual(mem2.recommendations[0], "考虑防御性板块") + + +class TestEdgeCases(unittest.TestCase): + """Graceful handling of corrupt / missing files.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.config = {"memory_persist_dir": self.tmpdir} + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_corrupt_json_starts_fresh(self): + fp = pathlib.Path(self.tmpdir) / "bad.json" + fp.write_text("NOT VALID JSON {{{", encoding="utf-8") + mem = FinancialSituationMemory("bad", config=self.config) + self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file") + + def test_missing_keys_starts_fresh(self): + fp = pathlib.Path(self.tmpdir) / "partial.json" + fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8") + mem = FinancialSituationMemory("partial", config=self.config) + self.assertEqual(len(mem.documents), 0) + + def test_mismatched_lengths_truncates(self): + """If situations and recommendations have different lengths, zip truncates.""" + fp = pathlib.Path(self.tmpdir) / "mismatch.json" + fp.write_text(json.dumps({ + "situations": ["s1", "s2", "s3"], + "recommendations": ["r1", "r2"], + }), encoding="utf-8") + mem = FinancialSituationMemory("mismatch", config=self.config) + self.assertEqual(len(mem.documents), 2) + self.assertEqual(len(mem.recommendations), 2) + + def test_parent_dirs_created(self): + nested = os.path.join(self.tmpdir, "a", "b", "c") + config = {"memory_persist_dir": nested} + mem = FinancialSituationMemory("nested", config=config) + mem.add_situations(SAMPLE_DATA[:1]) + self.assertTrue(pathlib.Path(nested).is_dir()) + self.assertTrue((pathlib.Path(nested) / "nested.json").exists()) + + def 
test_tilde_expansion(self): + """~ in persist dir is expanded.""" + config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"} + try: + mem = FinancialSituationMemory("tilde", config=config) + self.assertFalse(str(mem._persist_path).startswith("~")) + self.assertIn(os.path.expanduser("~"), str(mem._persist_path)) + finally: + import shutil + shutil.rmtree( + os.path.expanduser("~/.__test_tradingagents_mem__"), + ignore_errors=True, + ) + + +class TestMultipleInstances(unittest.TestCase): + """Multiple memory instances in the same directory don't collide.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.config = {"memory_persist_dir": self.tmpdir} + + def tearDown(self): + import shutil + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_independent_files(self): + bull = FinancialSituationMemory("bull_memory", config=self.config) + bear = FinancialSituationMemory("bear_memory", config=self.config) + + bull.add_situations(SAMPLE_DATA[:1]) + bear.add_situations(SAMPLE_DATA[1:2]) + + del bull, bear + + bull2 = FinancialSituationMemory("bull_memory", config=self.config) + bear2 = FinancialSituationMemory("bear_memory", config=self.config) + + self.assertEqual(len(bull2.documents), 1) + self.assertEqual(len(bear2.documents), 1) + self.assertNotEqual(bull2.documents[0], bear2.documents[0]) + + +class TestDefaultConfig(unittest.TestCase): + """Verify the default config includes the new key.""" + + def test_key_exists(self): + self.assertIn("memory_persist_dir", DEFAULT_CONFIG) + + def test_default_is_none(self): + self.assertIsNone( + DEFAULT_CONFIG["memory_persist_dir"], + "Default should be None (RAM-only) for backward compatibility", + ) + + +class TestTradingGraphIntegration(unittest.TestCase): + """Verify TradingAgentsGraph passes config to all memory instances. 
+ + This is a source-code audit test — it does NOT instantiate the graph + (which requires LLM API keys and heavy dependencies), but instead + inspects the constructor to confirm the config dict is forwarded. + """ + + def test_config_forwarded_to_memory_constructors(self): + """All five FinancialSituationMemory() calls receive self.config.""" + import inspect + from tradingagents.graph.trading_graph import TradingAgentsGraph + + source = inspect.getsource(TradingAgentsGraph.__init__) + + memory_names = [ + "bull_memory", + "bear_memory", + "trader_memory", + "invest_judge_memory", + "portfolio_manager_memory", + ] + + for name in memory_names: + # Expect: FinancialSituationMemory("name", self.config) + self.assertIn( + f'FinancialSituationMemory("{name}", self.config)', + source, + f"TradingAgentsGraph.__init__ should pass self.config to {name}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tradingagents/agents/utils/memory.py b/tradingagents/agents/utils/memory.py index 2aefa7a3..94982d21 100644 --- a/tradingagents/agents/utils/memory.py +++ b/tradingagents/agents/utils/memory.py @@ -2,28 +2,94 @@ Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls, no token limits, works offline with any LLM provider. + +Supports optional JSON file persistence via config["memory_persist_dir"]. +When set, memories survive process restarts; when unset, behaves as +before (RAM-only). """ +import json +import pathlib + from rank_bm25 import BM25Okapi from typing import List, Tuple import re class FinancialSituationMemory: - """Memory system for storing and retrieving financial situations using BM25.""" + """Memory system for storing and retrieving financial situations using BM25. + + Persistence + ----------- + Pass ``config={"memory_persist_dir": "/some/path"}`` to enable JSON + file persistence. Each memory instance writes a + ``{memory_persist_dir}/{name}.json`` file that is loaded on + construction and updated on every ``add_situations`` / ``clear`` + call. 
When ``memory_persist_dir`` is *not* set (the default), + the class behaves identically to the original RAM-only version. + """ def __init__(self, name: str, config: dict = None): """Initialize the memory system. Args: name: Name identifier for this memory instance - config: Configuration dict (kept for API compatibility, not used for BM25) + config: Configuration dict. Recognises the key + ``memory_persist_dir`` — when set to a directory path, + memories are persisted to ``{memory_persist_dir}/{name}.json``. """ self.name = name self.documents: List[str] = [] self.recommendations: List[str] = [] self.bm25 = None + # Resolve persistence path (may be None → RAM-only) + self._persist_path: pathlib.Path | None = None + if config and config.get("memory_persist_dir"): + d = pathlib.Path(config["memory_persist_dir"]).expanduser() + d.mkdir(parents=True, exist_ok=True) + self._persist_path = d / f"{name}.json" + + # Load previously persisted memories (no-op when path is None) + self._load() + + # ------------------------------------------------------------------ + # Persistence helpers + # ------------------------------------------------------------------ + + def _load(self): + """Load previously persisted memories from disk (no-op when RAM-only).""" + if self._persist_path is None or not self._persist_path.exists(): + return + try: + data = json.loads(self._persist_path.read_text(encoding="utf-8")) + situations = data.get("situations", []) + recommendations = data.get("recommendations", []) + # Pair them up; ignore mismatched trailing entries + pairs = list(zip(situations, recommendations)) + if pairs: + self.add_situations(pairs) + except (json.JSONDecodeError, OSError): + # Corrupt / unreadable file — start fresh + pass + + def _save(self): + """Persist current memories to disk (no-op when RAM-only).""" + if self._persist_path is None: + return + payload = { + "situations": list(self.documents), + "recommendations": list(self.recommendations), + } + # Atomic-ish write: write to temp then rename 
+ tmp = self._persist_path.with_suffix(".tmp") + tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + tmp.replace(self._persist_path) + + # ------------------------------------------------------------------ + # Tokenisation & indexing + # ------------------------------------------------------------------ + def _tokenize(self, text: str) -> List[str]: """Tokenize text for BM25 indexing. @@ -54,6 +120,9 @@ class FinancialSituationMemory: # Rebuild BM25 index with new documents self._rebuild_index() + # Persist to disk (no-op when RAM-only) + self._save() + def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]: """Find matching recommendations using BM25 similarity. @@ -96,6 +165,7 @@ class FinancialSituationMemory: self.documents = [] self.recommendations = [] self.bm25 = None + self._save() # Persist the empty state if __name__ == "__main__": diff --git a/tradingagents/default_config.py b/tradingagents/default_config.py index a9b75e4b..497f3d27 100644 --- a/tradingagents/default_config.py +++ b/tradingagents/default_config.py @@ -18,6 +18,8 @@ DEFAULT_CONFIG = { # Output language for analyst reports and final decision # Internal agent debate stays in English for reasoning quality "output_language": "English", + # Memory persistence (None = RAM-only, path = JSON file persistence) + "memory_persist_dir": None, # Debate and discussion settings "max_debate_rounds": 1, "max_risk_discuss_rounds": 1,