fix: add JSON file persistence to FinancialSituationMemory

FinancialSituationMemory stored all learned lessons in RAM-only Python
lists with no persistence layer.  Every process restart or new
TradingAgentsGraph() instance wiped all memory, making
reflect_and_remember() useless in practice — especially in server/API
deployments where a new graph is created per request.

Changes:

1. memory.py — add optional JSON file persistence:
   - New config key `memory_persist_dir`: when set to a directory path,
     each memory instance writes `<dir>/<name>.json` on every mutation
     (add_situations / clear) and loads it on construction.
   - When unset or None (the default), behaviour is identical to before
     (RAM-only) — fully backward compatible.
   - Atomic-ish writes via .tmp → rename to avoid corruption on crash.
   - Graceful handling of corrupt / missing / partial JSON files.
   - Tilde expansion (`~/...`) and automatic parent directory creation.

2. default_config.py — add `memory_persist_dir: None` to DEFAULT_CONFIG.

3. main.py — enable persistence in the example and improve
   reflect_and_remember documentation comment.

4. tests/test_memory_persistence.py — 21 regression tests covering:
   - RAM-only backward compatibility (5 tests)
   - Persistence round-trip, incremental add, clear, BM25 rebuild,
     JSON schema, Unicode (7 tests)
   - Edge cases: corrupt JSON, missing keys, mismatched lengths,
     nested directory creation, tilde expansion (5 tests)
   - Multiple instances sharing same directory (1 test)
   - Default config key existence (2 tests)
   - Source audit: TradingAgentsGraph passes config to all 5 memories (1 test)

Closes #563
This commit is contained in:
voidborne-d 2026-04-17 15:54:06 +00:00 committed by voidborne-d
parent fa4d01c23a
commit 1e2164ec78
4 changed files with 377 additions and 4 deletions

13
main.py
View File

@ -20,6 +20,10 @@ config["data_vendors"] = {
"news_data": "yfinance", # Options: alpha_vantage, yfinance
}
# Enable memory persistence so lessons survive restarts (optional).
# Set to None or omit to keep the default RAM-only behaviour.
config["memory_persist_dir"] = "~/.tradingagents/memory"
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
@ -27,5 +31,10 @@ ta = TradingAgentsGraph(debug=True, config=config)
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
# Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns
# Reflect on the decision after observing actual returns.
# Call this once the position closes and the P&L is known.
# The signed float indicates outcome: positive = correct signal,
# negative = incorrect signal. Lessons are persisted when
# memory_persist_dir is set, so the next TradingAgentsGraph
# instance will load them automatically.
# ta.reflect_and_remember(0.03) # e.g. 3 % gain

View File

@ -0,0 +1,292 @@
"""Tests for FinancialSituationMemory JSON file persistence.
Covers:
- RAM-only mode (no config / config without key) behaves identically to before
- Persistence: add_situations restart memories survive
- Persistence: clear restart file reflects empty state
- Atomic write: .tmp .json rename
- Corrupt / missing file is handled gracefully
- BM25 index is rebuilt after loading from disk
- Multiple memory instances sharing the same dir don't collide
- Default config includes memory_persist_dir key
- TradingAgentsGraph passes config through to all five memory instances
"""
import json
import os
import pathlib
import tempfile
import unittest
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.default_config import DEFAULT_CONFIG
SAMPLE_DATA = [
(
"High inflation rate with rising interest rates",
"Consider defensive sectors like utilities.",
),
(
"Tech sector volatility with institutional selling pressure",
"Reduce high-growth tech exposure.",
),
(
"Strong dollar affecting emerging markets",
"Hedge currency exposure in international positions.",
),
]
class TestRamOnlyMode(unittest.TestCase):
"""Ensure RAM-only mode (no persistence) is fully backward-compatible."""
def test_no_config(self):
mem = FinancialSituationMemory("test")
self.assertIsNone(mem._persist_path)
mem.add_situations(SAMPLE_DATA[:1])
self.assertEqual(len(mem.documents), 1)
def test_config_without_persist_key(self):
mem = FinancialSituationMemory("test", config={"some_other_key": True})
self.assertIsNone(mem._persist_path)
def test_config_with_none_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": None})
self.assertIsNone(mem._persist_path)
def test_config_with_empty_string_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": ""})
self.assertIsNone(mem._persist_path)
def test_ram_only_data_lost_on_new_instance(self):
mem1 = FinancialSituationMemory("test")
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
mem2 = FinancialSituationMemory("test")
self.assertEqual(len(mem2.documents), 0)
class TestPersistence(unittest.TestCase):
"""Core persistence round-trip tests."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_file_created_on_add(self):
mem = FinancialSituationMemory("bull_memory", config=self.config)
self.assertFalse(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should not exist before adding data",
)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should be created after add_situations",
)
def test_round_trip(self):
"""Data survives a full destroy-and-recreate cycle."""
mem1 = FinancialSituationMemory("rt", config=self.config)
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
# Destroy first instance, create a new one — simulates process restart
del mem1
mem2 = FinancialSituationMemory("rt", config=self.config)
self.assertEqual(len(mem2.documents), 3)
self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0])
self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1])
def test_incremental_add_persists(self):
"""Multiple add_situations calls accumulate correctly on disk."""
mem = FinancialSituationMemory("inc", config=self.config)
mem.add_situations(SAMPLE_DATA[:1])
mem.add_situations(SAMPLE_DATA[1:2])
del mem
mem2 = FinancialSituationMemory("inc", config=self.config)
self.assertEqual(len(mem2.documents), 2)
def test_clear_persists_empty(self):
mem = FinancialSituationMemory("clr", config=self.config)
mem.add_situations(SAMPLE_DATA)
mem.clear()
del mem
mem2 = FinancialSituationMemory("clr", config=self.config)
self.assertEqual(len(mem2.documents), 0)
self.assertEqual(len(mem2.recommendations), 0)
def test_bm25_rebuilt_after_load(self):
"""BM25 index should work after loading from disk."""
mem1 = FinancialSituationMemory("bm25", config=self.config)
mem1.add_situations(SAMPLE_DATA)
del mem1
mem2 = FinancialSituationMemory("bm25", config=self.config)
results = mem2.get_memories("rising interest rates inflation", n_matches=1)
self.assertEqual(len(results), 1)
self.assertIn("inflation", results[0]["matched_situation"].lower())
def test_json_file_content_valid(self):
"""Persisted JSON has the expected schema."""
mem = FinancialSituationMemory("schema", config=self.config)
mem.add_situations(SAMPLE_DATA[:2])
fp = pathlib.Path(self.tmpdir) / "schema.json"
data = json.loads(fp.read_text(encoding="utf-8"))
self.assertIn("situations", data)
self.assertIn("recommendations", data)
self.assertEqual(len(data["situations"]), 2)
self.assertEqual(len(data["recommendations"]), 2)
def test_unicode_round_trip(self):
"""Non-ASCII data (e.g. CJK) survives persistence."""
mem = FinancialSituationMemory("uni", config=self.config)
mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")])
del mem
mem2 = FinancialSituationMemory("uni", config=self.config)
self.assertEqual(mem2.documents[0], "通胀上升,利率攀升")
self.assertEqual(mem2.recommendations[0], "考虑防御性板块")
class TestEdgeCases(unittest.TestCase):
"""Graceful handling of corrupt / missing files."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_corrupt_json_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "bad.json"
fp.write_text("NOT VALID JSON {{{", encoding="utf-8")
mem = FinancialSituationMemory("bad", config=self.config)
self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file")
def test_missing_keys_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "partial.json"
fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8")
mem = FinancialSituationMemory("partial", config=self.config)
self.assertEqual(len(mem.documents), 0)
def test_mismatched_lengths_truncates(self):
"""If situations and recommendations have different lengths, zip truncates."""
fp = pathlib.Path(self.tmpdir) / "mismatch.json"
fp.write_text(json.dumps({
"situations": ["s1", "s2", "s3"],
"recommendations": ["r1", "r2"],
}), encoding="utf-8")
mem = FinancialSituationMemory("mismatch", config=self.config)
self.assertEqual(len(mem.documents), 2)
self.assertEqual(len(mem.recommendations), 2)
def test_parent_dirs_created(self):
nested = os.path.join(self.tmpdir, "a", "b", "c")
config = {"memory_persist_dir": nested}
mem = FinancialSituationMemory("nested", config=config)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(pathlib.Path(nested).is_dir())
self.assertTrue((pathlib.Path(nested) / "nested.json").exists())
def test_tilde_expansion(self):
"""~ in persist dir is expanded."""
config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"}
try:
mem = FinancialSituationMemory("tilde", config=config)
self.assertFalse(str(mem._persist_path).startswith("~"))
self.assertIn(os.path.expanduser("~"), str(mem._persist_path))
finally:
import shutil
shutil.rmtree(
os.path.expanduser("~/.__test_tradingagents_mem__"),
ignore_errors=True,
)
class TestMultipleInstances(unittest.TestCase):
"""Multiple memory instances in the same directory don't collide."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_independent_files(self):
bull = FinancialSituationMemory("bull_memory", config=self.config)
bear = FinancialSituationMemory("bear_memory", config=self.config)
bull.add_situations(SAMPLE_DATA[:1])
bear.add_situations(SAMPLE_DATA[1:2])
del bull, bear
bull2 = FinancialSituationMemory("bull_memory", config=self.config)
bear2 = FinancialSituationMemory("bear_memory", config=self.config)
self.assertEqual(len(bull2.documents), 1)
self.assertEqual(len(bear2.documents), 1)
self.assertNotEqual(bull2.documents[0], bear2.documents[0])
class TestDefaultConfig(unittest.TestCase):
"""Verify the default config includes the new key."""
def test_key_exists(self):
self.assertIn("memory_persist_dir", DEFAULT_CONFIG)
def test_default_is_none(self):
self.assertIsNone(
DEFAULT_CONFIG["memory_persist_dir"],
"Default should be None (RAM-only) for backward compatibility",
)
class TestTradingGraphIntegration(unittest.TestCase):
"""Verify TradingAgentsGraph passes config to all memory instances.
This is a source-code audit test it does NOT instantiate the graph
(which requires LLM API keys and heavy dependencies), but instead
inspects the constructor to confirm the config dict is forwarded.
"""
def test_config_forwarded_to_memory_constructors(self):
"""All five FinancialSituationMemory() calls receive self.config."""
import inspect
from tradingagents.graph.trading_graph import TradingAgentsGraph
source = inspect.getsource(TradingAgentsGraph.__init__)
memory_names = [
"bull_memory",
"bear_memory",
"trader_memory",
"invest_judge_memory",
"portfolio_manager_memory",
]
for name in memory_names:
# Expect: FinancialSituationMemory("name", self.config)
self.assertIn(
f'FinancialSituationMemory("{name}", self.config)',
source,
f"TradingAgentsGraph.__init__ should pass self.config to {name}",
)
if __name__ == "__main__":
unittest.main()

View File

@ -2,28 +2,94 @@
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
Supports optional JSON file persistence via config["memory_persist_dir"].
When set, memories survive process restarts; when unset, behaves as
before (RAM-only).
"""
import json
import pathlib
from rank_bm25 import BM25Okapi
from typing import List, Tuple
import re
class FinancialSituationMemory:
"""Memory system for storing and retrieving financial situations using BM25."""
"""Memory system for storing and retrieving financial situations using BM25.
Persistence
-----------
Pass ``config={"memory_persist_dir": "/some/path"}`` to enable JSON
file persistence. Each memory instance writes a
``<memory_persist_dir>/<name>.json`` file that is loaded on
construction and updated on every ``add_situations`` / ``clear``
call. When ``memory_persist_dir`` is *not* set (the default),
the class behaves identically to the original RAM-only version.
"""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
config: Configuration dict. Recognises the key
``memory_persist_dir`` when set to a directory path,
memories are persisted to ``<dir>/<name>.json``.
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
# Resolve persistence path (may be None → RAM-only)
self._persist_path: pathlib.Path | None = None
if config and config.get("memory_persist_dir"):
d = pathlib.Path(config["memory_persist_dir"]).expanduser()
d.mkdir(parents=True, exist_ok=True)
self._persist_path = d / f"{name}.json"
# Load previously persisted memories (no-op when path is None)
self._load()
# ------------------------------------------------------------------
# Persistence helpers
# ------------------------------------------------------------------
def _load(self):
"""Load previously persisted memories from disk (no-op when RAM-only)."""
if self._persist_path is None or not self._persist_path.exists():
return
try:
data = json.loads(self._persist_path.read_text(encoding="utf-8"))
situations = data.get("situations", [])
recommendations = data.get("recommendations", [])
# Pair them up; ignore mismatched trailing entries
pairs = list(zip(situations, recommendations))
if pairs:
self.add_situations(pairs)
except (json.JSONDecodeError, OSError):
# Corrupt / unreadable file — start fresh
pass
def _save(self):
"""Persist current memories to disk (no-op when RAM-only)."""
if self._persist_path is None:
return
payload = {
"situations": list(self.documents),
"recommendations": list(self.recommendations),
}
# Atomic-ish write: write to temp then rename
tmp = self._persist_path.with_suffix(".tmp")
tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
tmp.replace(self._persist_path)
# ------------------------------------------------------------------
# Tokenisation & indexing
# ------------------------------------------------------------------
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
@ -54,6 +120,9 @@ class FinancialSituationMemory:
# Rebuild BM25 index with new documents
self._rebuild_index()
# Persist to disk (no-op when RAM-only)
self._save()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
@ -96,6 +165,7 @@ class FinancialSituationMemory:
self.documents = []
self.recommendations = []
self.bm25 = None
self._save() # Persist the empty state
if __name__ == "__main__":

View File

@ -18,6 +18,8 @@ DEFAULT_CONFIG = {
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
# Memory persistence (None = RAM-only, path = JSON file persistence)
"memory_persist_dir": None,
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,