feat: persist FinancialSituationMemory to disk (#563)

This commit is contained in:
Zhigong Liu 2026-04-18 21:13:14 -04:00
parent 8536ccacdd
commit bf7d27e0a9
3 changed files with 239 additions and 1 deletions

View File

@ -0,0 +1,166 @@
"""Tests for FinancialSituationMemory persistence (issue #563)."""
import json
import pytest
from pathlib import Path
from tradingagents.agents.utils.memory import FinancialSituationMemory
@pytest.fixture
def persist_dir(tmp_path):
    """Return, as a string, the path ``<tmp_path>/memory``.

    The directory itself is not created here; only its path is computed.
    """
    memory_dir = tmp_path / "memory"
    return str(memory_dir)
def make_config(persist_dir):
    """Build a config dict whose single key ``memory_persist_dir`` maps to *persist_dir*."""
    return dict(memory_persist_dir=persist_dir)
# ---------------------------------------------------------------------------
# Persistence: data survives a fresh instance
# ---------------------------------------------------------------------------
def test_data_survives_restart(persist_dir):
    """An entry added by one instance is visible to a later instance with the same name and directory."""
    config = make_config(persist_dir)

    writer = FinancialSituationMemory("test", config)
    writer.add_situations([("situation A", "recommendation A")])

    # A fresh instance stands in for a process restart.
    reader = FinancialSituationMemory("test", config)
    assert reader.documents == ["situation A"]
    assert reader.recommendations == ["recommendation A"]
def test_multiple_entries_survive_restart(persist_dir):
    """Every entry added in one batch is present when a new instance is created."""
    pairs = [
        ("situation A", "rec A"),
        ("situation B", "rec B"),
    ]
    original = FinancialSituationMemory("test", make_config(persist_dir))
    original.add_situations(pairs)

    reloaded = FinancialSituationMemory("test", make_config(persist_dir))
    assert len(reloaded.documents) == 2
    assert len(reloaded.recommendations) == 2
def test_bm25_index_rebuilt_on_load(persist_dir):
    """A freshly loaded instance can answer ``get_memories`` queries."""
    seeded = FinancialSituationMemory("test", make_config(persist_dir))
    seeded.add_situations([("rising interest rates inflation", "reduce duration")])

    restored = FinancialSituationMemory("test", make_config(persist_dir))
    matches = restored.get_memories("inflation rate rising", n_matches=1)

    assert len(matches) == 1
    top_match = matches[0]
    assert top_match["recommendation"] == "reduce duration"
# ---------------------------------------------------------------------------
# RAM-only mode: no persist_dir → no file written
# ---------------------------------------------------------------------------
def test_no_persist_dir_no_file(tmp_path, monkeypatch):
    """When memory_persist_dir is absent, no persist path is set and data stays in RAM."""
    # Run inside tmp_path so that a stray write to a relative path would land
    # there. Without this, nothing connects the memory object to tmp_path and
    # the emptiness check at the end could never fail.
    monkeypatch.chdir(tmp_path)

    m = FinancialSituationMemory("test", config={})
    m.add_situations([("situation", "rec")])

    assert m._persist_path is None
    # Data is still accessible in RAM
    assert m.documents == ["situation"]
    assert m.recommendations == ["rec"]
    # Nothing was written to the working directory
    assert list(tmp_path.iterdir()) == []
def test_none_config_no_file(tmp_path, monkeypatch):
    """When config is None (default), no persist path is set and data stays in RAM."""
    # Run inside tmp_path so that a stray write to a relative path would land
    # there. Without this, nothing connects the memory object to tmp_path and
    # the emptiness check at the end could never fail.
    monkeypatch.chdir(tmp_path)

    m = FinancialSituationMemory("test")
    m.add_situations([("situation", "rec")])

    assert m._persist_path is None
    # Data is still accessible in RAM
    assert m.documents == ["situation"]
    assert m.recommendations == ["rec"]
    # Nothing was written to the working directory
    assert list(tmp_path.iterdir()) == []
# ---------------------------------------------------------------------------
# Instance isolation: separate names → separate files
# ---------------------------------------------------------------------------
def test_separate_names_separate_files(persist_dir):
    """Differently named memories keep separate contents and separate JSON files."""
    bull_writer = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    bear_writer = FinancialSituationMemory("bear_memory", make_config(persist_dir))
    bull_writer.add_situations([("bull situation", "buy")])
    bear_writer.add_situations([("bear situation", "sell")])

    # Fresh instances must each see only the entries stored under their own name.
    bull_reader = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    bear_reader = FinancialSituationMemory("bear_memory", make_config(persist_dir))
    assert bull_reader.documents == ["bull situation"]
    assert bear_reader.documents == ["bear situation"]

    written_names = set()
    for entry in Path(persist_dir).iterdir():
        written_names.add(entry.name)
    assert "bull_memory.json" in written_names
    assert "bear_memory.json" in written_names
# ---------------------------------------------------------------------------
# clear() persists the empty state
# ---------------------------------------------------------------------------
def test_clear_persists(persist_dir):
    """Clearing a memory is itself persisted: the next instance starts empty."""
    config = make_config(persist_dir)

    first = FinancialSituationMemory("test", config)
    first.add_situations([("situation", "rec")])
    first.clear()

    second = FinancialSituationMemory("test", config)
    assert second.documents == []
    assert second.bm25 is None
# ---------------------------------------------------------------------------
# Resilience: corrupt or mismatched files fall back to empty memory
# ---------------------------------------------------------------------------
def test_corrupt_json_falls_back_to_empty(persist_dir):
    """An unparsable persist file does not crash construction; memory starts empty."""
    directory = Path(persist_dir)
    directory.mkdir(parents=True, exist_ok=True)
    broken_file = directory / "test.json"
    broken_file.write_text("not valid json", encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
def test_mismatched_lengths_falls_back_to_empty(persist_dir):
    """A persist file with two documents but one recommendation is rejected."""
    directory = Path(persist_dir)
    directory.mkdir(parents=True, exist_ok=True)

    payload = json.dumps({"documents": ["a", "b"], "recommendations": ["r1"]})
    (directory / "test.json").write_text(payload, encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
# ---------------------------------------------------------------------------
# File format: JSON is human-readable and well-formed
# ---------------------------------------------------------------------------
def test_file_is_valid_json(persist_dir):
    """The file written after ``add_situations`` parses as JSON with two list-valued keys."""
    memory = FinancialSituationMemory("test", make_config(persist_dir))
    memory.add_situations([("situation", "rec")])

    saved = Path(persist_dir) / "test.json"
    assert saved.exists()

    payload = json.loads(saved.read_text(encoding="utf-8"))

    assert "documents" in payload
    assert "recommendations" in payload
    assert isinstance(payload["documents"], list)
    assert isinstance(payload["recommendations"], list)

View File

@ -4,6 +4,12 @@ Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
"""
import json
import logging
import tempfile
from pathlib import Path
logger = logging.getLogger(__name__)
from rank_bm25 import BM25Okapi
from typing import List, Tuple
import re
@ -17,12 +23,75 @@ class FinancialSituationMemory:
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
config: Configuration dict. If config contains a non-empty
``memory_persist_dir`` key, documents and recommendations
are loaded from (and saved to) that directory.
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
self._persist_path = None
if config:
persist_dir = config.get("memory_persist_dir")
if persist_dir:
self._persist_path = Path(persist_dir) / f"{name}.json"
self._load()
def _load(self):
    """Load documents and recommendations from disk if the persist file exists.

    Loading is best-effort. When no persist path is configured or the file
    does not exist, nothing happens. When the file cannot be read, is not
    valid UTF-8 JSON, or does not have the expected structure, a warning is
    logged and the memory is left as it was (empty when called from the
    constructor) instead of raising.

    Expected file content: a JSON object with a ``documents`` list of
    strings and a ``recommendations`` list of the same length. A missing
    key is treated as an empty list.
    """
    if not (self._persist_path and self._persist_path.exists()):
        return

    try:
        with open(self._persist_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError (malformed JSON) and
        # UnicodeDecodeError (file is not valid UTF-8).
        logger.warning(
            "Could not load memory from %s (%s). Starting with empty memory.",
            self._persist_path,
            exc,
        )
        return

    # json.load can return any JSON value, so validate the shape before
    # touching instance state.
    problem = None
    docs = recs = None
    if not isinstance(data, dict):
        problem = "top-level value is not a JSON object"
    else:
        docs = data.get("documents", [])
        recs = data.get("recommendations", [])
        if not isinstance(docs, list) or not isinstance(recs, list):
            problem = "documents/recommendations are not lists"
        elif not all(isinstance(doc, str) for doc in docs):
            # Documents are passed to the tokenizer when the index is
            # rebuilt, which is annotated as taking ``str``.
            problem = "documents contains non-string entries"
        elif len(docs) != len(recs):
            problem = "documents/recommendations length mismatch"

    if problem is not None:
        logger.warning(
            "Memory file %s is corrupt (%s). Starting with empty memory.",
            self._persist_path,
            problem,
        )
        return

    self.documents = docs
    self.recommendations = recs
    if self.documents:
        self._rebuild_index()
def _save(self):
    """Persist documents and recommendations to disk atomically.

    Does nothing when no persist path is configured. Otherwise the data is
    written as JSON to a temporary file in the target directory, which is
    then renamed over the persist file so readers never see a partial write.

    Saving is best-effort: an ``OSError`` while creating the directory,
    writing, or renaming is logged as a warning and not raised. Other
    exceptions (for example a serialization error from ``json.dump``)
    propagate. In every failure case the temporary file is removed so
    failed saves do not accumulate ``*.tmp`` files.
    """
    if not self._persist_path:
        return

    tmp_path = None  # set once the temp file exists, cleared after the rename
    try:
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=self._persist_path.parent,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            tmp_path = Path(tmp.name)
            json.dump(
                {
                    "documents": self.documents,
                    "recommendations": self.recommendations,
                },
                tmp,
                indent=2,
                ensure_ascii=False,
            )
        # The rename happens after the file is closed so all data is flushed.
        tmp_path.replace(self._persist_path)
        tmp_path = None
    except OSError as exc:
        logger.warning("Could not save memory to %s (%s).", self._persist_path, exc)
    finally:
        if tmp_path is not None:
            # The save did not complete; drop the leftover temp file.
            try:
                tmp_path.unlink()
            except OSError:
                # Cleanup is best-effort; the original failure is what matters.
                pass
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
@ -53,6 +122,7 @@ class FinancialSituationMemory:
# Rebuild BM25 index with new documents
self._rebuild_index()
self._save()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
@ -96,6 +166,7 @@ class FinancialSituationMemory:
self.documents = []
self.recommendations = []
self.bm25 = None
self._save()
if __name__ == "__main__":

View File

@ -6,6 +6,7 @@ DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
"memory_persist_dir": os.path.join(_TRADINGAGENTS_HOME, "memory"),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "gpt-5.4",