feat: persist FinancialSituationMemory to disk (#563)

This commit is contained in:
Zhigong Liu 2026-04-18 21:13:14 -04:00
parent 8536ccacdd
commit bf7d27e0a9
3 changed files with 239 additions and 1 deletions

View File

@ -0,0 +1,166 @@
"""Tests for FinancialSituationMemory persistence (issue #563)."""
import json
import pytest
from pathlib import Path
from tradingagents.agents.utils.memory import FinancialSituationMemory
@pytest.fixture
def persist_dir(tmp_path):
    """Return, as a string, a ``memory`` path under pytest's ``tmp_path``.

    The directory itself is not created here; only the path is built.
    """
    memory_root = tmp_path / "memory"
    return str(memory_root)
def make_config(persist_dir):
    """Build a config dict whose only key is ``memory_persist_dir``."""
    return dict(memory_persist_dir=persist_dir)
# ---------------------------------------------------------------------------
# Persistence: data survives a fresh instance
# ---------------------------------------------------------------------------
def test_data_survives_restart(persist_dir):
    """A fresh instance on the same directory sees what an earlier one stored."""
    writer = FinancialSituationMemory("test", make_config(persist_dir))
    writer.add_situations([("situation A", "recommendation A")])

    # A second instance with the same name and directory stands in for a restart.
    reader = FinancialSituationMemory("test", make_config(persist_dir))

    expected_docs = ["situation A"]
    expected_recs = ["recommendation A"]
    assert reader.documents == expected_docs
    assert reader.recommendations == expected_recs
def test_multiple_entries_survive_restart(persist_dir):
    """All entries are preserved, in insertion order, across instances.

    Exact contents are compared rather than only lengths, so a reload that
    reorders entries, swaps the two lists, or corrupts a value fails here.
    """
    entries = [
        ("situation A", "rec A"),
        ("situation B", "rec B"),
    ]
    m1 = FinancialSituationMemory("test", make_config(persist_dir))
    m1.add_situations(entries)

    m2 = FinancialSituationMemory("test", make_config(persist_dir))

    assert m2.documents == [situation for situation, _ in entries]
    assert m2.recommendations == [rec for _, rec in entries]
def test_bm25_index_rebuilt_on_load(persist_dir):
    """Retrieval works on an instance whose data came from disk."""
    seed = FinancialSituationMemory("test", make_config(persist_dir))
    seed.add_situations([("rising interest rates inflation", "reduce duration")])

    reloaded = FinancialSituationMemory("test", make_config(persist_dir))
    matches = reloaded.get_memories("inflation rate rising", n_matches=1)

    assert len(matches) == 1
    top_match = matches[0]
    assert top_match["recommendation"] == "reduce duration"
# ---------------------------------------------------------------------------
# RAM-only mode: no persist_dir → no file written
# ---------------------------------------------------------------------------
def test_no_persist_dir_no_file(tmp_path, monkeypatch):
    """When memory_persist_dir is absent, no persist path is set and data stays in RAM.

    The working directory is switched to ``tmp_path`` first. Without that,
    the "nothing on disk" assertion could never fail, because nothing ties
    the memory object to ``tmp_path``; with it, any stray write to a
    relative path would show up there.
    """
    monkeypatch.chdir(tmp_path)

    m = FinancialSituationMemory("test", config={})
    m.add_situations([("situation", "rec")])

    assert m._persist_path is None
    # Data is still accessible in RAM
    assert m.documents == ["situation"]
    assert m.recommendations == ["rec"]
    # Nothing was written relative to the working directory
    assert list(tmp_path.iterdir()) == []
def test_none_config_no_file(tmp_path, monkeypatch):
    """When config is None (default), no persist path is set and data stays in RAM.

    The working directory is switched to ``tmp_path`` first. Without that,
    the "nothing on disk" assertion could never fail, because nothing ties
    the memory object to ``tmp_path``; with it, any stray write to a
    relative path would show up there.
    """
    monkeypatch.chdir(tmp_path)

    m = FinancialSituationMemory("test")
    m.add_situations([("situation", "rec")])

    assert m._persist_path is None
    # Data is still accessible in RAM
    assert m.documents == ["situation"]
    assert m.recommendations == ["rec"]
    # Nothing was written relative to the working directory
    assert list(tmp_path.iterdir()) == []
# ---------------------------------------------------------------------------
# Instance isolation: separate names → separate files
# ---------------------------------------------------------------------------
def test_separate_names_separate_files(persist_dir):
    """Differently named memories keep separate state and separate files."""
    first_bull = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    first_bear = FinancialSituationMemory("bear_memory", make_config(persist_dir))
    first_bull.add_situations([("bull situation", "buy")])
    first_bear.add_situations([("bear situation", "sell")])

    # Reload each name and check it only sees its own entry.
    reloaded_bull = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    reloaded_bear = FinancialSituationMemory("bear_memory", make_config(persist_dir))
    assert reloaded_bull.documents == ["bull situation"]
    assert reloaded_bear.documents == ["bear situation"]

    # One JSON file per memory name must exist in the shared directory.
    on_disk = set()
    for entry in Path(persist_dir).iterdir():
        on_disk.add(entry.name)
    assert "bull_memory.json" in on_disk
    assert "bear_memory.json" in on_disk
# ---------------------------------------------------------------------------
# clear() persists the empty state
# ---------------------------------------------------------------------------
def test_clear_persists(persist_dir):
    """Clearing a memory is itself persisted: a later instance loads nothing."""
    original = FinancialSituationMemory("test", make_config(persist_dir))
    original.add_situations([("situation", "rec")])
    original.clear()

    successor = FinancialSituationMemory("test", make_config(persist_dir))

    assert successor.documents == []
    assert successor.bm25 is None
# ---------------------------------------------------------------------------
# Resilience: corrupt or mismatched files fall back to empty memory
# ---------------------------------------------------------------------------
def test_corrupt_json_falls_back_to_empty(persist_dir):
    """An unparseable persist file yields an empty memory instead of an error."""
    store = Path(persist_dir)
    store.mkdir(parents=True, exist_ok=True)
    broken_file = store / "test.json"
    broken_file.write_text("not valid json", encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
def test_mismatched_lengths_falls_back_to_empty(persist_dir):
    """A persist file with more documents than recommendations is discarded."""
    store = Path(persist_dir)
    store.mkdir(parents=True, exist_ok=True)

    # Two documents but only one recommendation: the lists cannot be paired up.
    payload = json.dumps({"documents": ["a", "b"], "recommendations": ["r1"]})
    (store / "test.json").write_text(payload, encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
# ---------------------------------------------------------------------------
# File format: JSON is human-readable and well-formed
# ---------------------------------------------------------------------------
def test_file_is_valid_json(persist_dir):
    """The file written on save parses as JSON and holds the two expected lists."""
    memory = FinancialSituationMemory("test", make_config(persist_dir))
    memory.add_situations([("situation", "rec")])

    saved = Path(persist_dir) / "test.json"
    assert saved.exists()

    payload = json.loads(saved.read_text(encoding="utf-8"))

    for key in ("documents", "recommendations"):
        assert key in payload
    for key in ("documents", "recommendations"):
        assert isinstance(payload[key], list)

View File

@ -4,6 +4,12 @@ Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider. no token limits, works offline with any LLM provider.
""" """
import json
import logging
import tempfile
from pathlib import Path
logger = logging.getLogger(__name__)
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from typing import List, Tuple from typing import List, Tuple
import re import re
@ -17,12 +23,75 @@ class FinancialSituationMemory:
Args: Args:
name: Name identifier for this memory instance name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25) config: Configuration dict. If config contains a non-empty
``memory_persist_dir`` key, documents and recommendations
are loaded from (and saved to) that directory.
""" """
self.name = name self.name = name
self.documents: List[str] = [] self.documents: List[str] = []
self.recommendations: List[str] = [] self.recommendations: List[str] = []
self.bm25 = None self.bm25 = None
self._persist_path = None
if config:
persist_dir = config.get("memory_persist_dir")
if persist_dir:
self._persist_path = Path(persist_dir) / f"{name}.json"
self._load()
def _load(self):
    """Load documents and recommendations from disk if the persist file exists.

    Loading is best-effort. Any unreadable, undecodable, or structurally
    invalid file is reported with a warning and leaves the memory empty
    rather than raising. A file is accepted only when it holds a JSON
    object whose ``documents`` and ``recommendations`` values (each
    defaulting to an empty list) are equally long lists of strings.
    """
    if not (self._persist_path and self._persist_path.exists()):
        return

    try:
        with open(self._persist_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (a file that is not valid UTF-8).
        logger.warning(
            "Could not load memory from %s (%s). Starting with empty memory.",
            self._persist_path,
            exc,
        )
        return

    # json.load may return any JSON type (list, str, number, null), so the
    # shape is checked before ``.get`` / ``len`` are used on it.
    docs = data.get("documents", []) if isinstance(data, dict) else None
    recs = data.get("recommendations", []) if isinstance(data, dict) else None
    well_formed = all(
        isinstance(seq, list) and all(isinstance(item, str) for item in seq)
        for seq in (docs, recs)
    )
    if not well_formed:
        logger.warning(
            "Memory file %s is corrupt (expected an object with string lists "
            "'documents' and 'recommendations'). Starting with empty memory.",
            self._persist_path,
        )
        return

    if len(docs) != len(recs):
        logger.warning(
            "Memory file %s is corrupt (documents/recommendations length mismatch). "
            "Starting with empty memory.",
            self._persist_path,
        )
        return

    self.documents = docs
    self.recommendations = recs
    if self.documents:
        self._rebuild_index()
def _save(self):
    """Persist documents and recommendations to disk atomically.

    Does nothing when no persist path is configured. The payload is first
    written to a temporary file in the destination directory and then
    moved over the target, so the target is either the previous version
    or the complete new one.

    Saving is best-effort with respect to filesystem errors: an OSError
    from creating the directory, writing, or replacing the file is logged
    as a warning and not raised. Other exceptions (for example a
    TypeError from a non-serializable entry) still propagate. In every
    failure case the temporary file is removed so failed saves do not
    accumulate ``*.tmp`` files next to the memory file.
    """
    if not self._persist_path:
        return

    tmp_path = None
    try:
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=self._persist_path.parent,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            # Record the name before writing so cleanup can find the file
            # even when serialization fails part-way through.
            tmp_path = Path(tmp.name)
            json.dump(
                {
                    "documents": self.documents,
                    "recommendations": self.recommendations,
                },
                tmp,
                indent=2,
                ensure_ascii=False,
            )
        tmp_path.replace(self._persist_path)
        tmp_path = None  # moved into place; nothing left to clean up
    except OSError as exc:
        logger.warning("Could not save memory to %s (%s).", self._persist_path, exc)
    finally:
        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except OSError:
                # Cleanup is best-effort; the file may already be gone.
                pass
def _tokenize(self, text: str) -> List[str]: def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing. """Tokenize text for BM25 indexing.
@ -53,6 +122,7 @@ class FinancialSituationMemory:
# Rebuild BM25 index with new documents # Rebuild BM25 index with new documents
self._rebuild_index() self._rebuild_index()
self._save()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]: def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity. """Find matching recommendations using BM25 similarity.
@ -96,6 +166,7 @@ class FinancialSituationMemory:
self.documents = [] self.documents = []
self.recommendations = [] self.recommendations = []
self.bm25 = None self.bm25 = None
self._save()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -6,6 +6,7 @@ DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")), "project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")), "results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")), "data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
"memory_persist_dir": os.path.join(_TRADINGAGENTS_HOME, "memory"),
# LLM settings # LLM settings
"llm_provider": "openai", "llm_provider": "openai",
"deep_think_llm": "gpt-5.4", "deep_think_llm": "gpt-5.4",