feat: persist FinancialSituationMemory to disk (#563)
This commit is contained in:
parent
8536ccacdd
commit
bf7d27e0a9
|
|
@ -0,0 +1,166 @@
|
|||
"""Tests for FinancialSituationMemory persistence (issue #563)."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
|
||||
|
||||
@pytest.fixture
def persist_dir(tmp_path):
    """Return, as a string, the path of a ``memory`` entry under ``tmp_path``.

    The fixture only builds the path; it does not create the directory.
    """
    memory_dir = tmp_path.joinpath("memory")
    return str(memory_dir)
|
||||
|
||||
|
||||
def make_config(persist_dir):
    """Build a config dict whose only key, ``memory_persist_dir``, is ``persist_dir``."""
    return dict(memory_persist_dir=persist_dir)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persistence: data survives a fresh instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_data_survives_restart(persist_dir):
    """An entry added by one instance is visible to a fresh instance with the same name."""
    writer = FinancialSituationMemory("test", make_config(persist_dir))
    writer.add_situations([("situation A", "recommendation A")])

    # A brand-new object pointed at the same directory stands in for a restart.
    reader = FinancialSituationMemory("test", make_config(persist_dir))

    assert reader.documents == ["situation A"]
    assert reader.recommendations == ["recommendation A"]
|
||||
|
||||
|
||||
def test_multiple_entries_survive_restart(persist_dir):
    """Both entries of a two-item batch are present after reloading from disk."""
    batch = [
        ("situation A", "rec A"),
        ("situation B", "rec B"),
    ]
    original = FinancialSituationMemory("test", make_config(persist_dir))
    original.add_situations(batch)

    reloaded = FinancialSituationMemory("test", make_config(persist_dir))

    # Only the counts are checked here, documents first, then recommendations.
    for stored in (reloaded.documents, reloaded.recommendations):
        assert len(stored) == 2
|
||||
|
||||
|
||||
def test_bm25_index_rebuilt_on_load(persist_dir):
    """``get_memories`` returns the stored recommendation on an instance loaded from disk."""
    seed = FinancialSituationMemory("test", make_config(persist_dir))
    seed.add_situations([("rising interest rates inflation", "reduce duration")])

    restored = FinancialSituationMemory("test", make_config(persist_dir))
    matches = restored.get_memories("inflation rate rising", n_matches=1)

    assert len(matches) == 1
    best = matches[0]
    assert best["recommendation"] == "reduce duration"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RAM-only mode: no persist_dir → no file written
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_no_persist_dir_no_file(tmp_path):
    """With an empty config there is no persist path and the data lives only in RAM."""
    memory = FinancialSituationMemory("test", config={})
    memory.add_situations([("situation", "rec")])

    # No persistence target was configured.
    assert memory._persist_path is None

    # The entry is still readable from the in-memory lists.
    assert memory.documents == ["situation"]
    assert memory.recommendations == ["rec"]

    # The temporary directory must still be empty.
    leftovers = list(tmp_path.iterdir())
    assert leftovers == []
|
||||
|
||||
|
||||
def test_none_config_no_file(tmp_path):
    """With the config argument omitted there is no persist path and data stays in RAM."""
    memory = FinancialSituationMemory("test")
    memory.add_situations([("situation", "rec")])

    # No persistence target exists when no config is passed.
    assert memory._persist_path is None

    # The entry is still readable from the in-memory lists.
    assert memory.documents == ["situation"]
    assert memory.recommendations == ["rec"]

    # The temporary directory must still be empty.
    leftovers = list(tmp_path.iterdir())
    assert leftovers == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Instance isolation: separate names → separate files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_separate_names_separate_files(persist_dir):
    """Memories with different names reload their own data and write their own JSON files."""
    entries = {
        "bull_memory": ("bull situation", "buy"),
        "bear_memory": ("bear situation", "sell"),
    }

    # Create both instances before adding anything, as two live agents would.
    live = {
        name: FinancialSituationMemory(name, make_config(persist_dir))
        for name in entries
    }
    for name, entry in entries.items():
        live[name].add_situations([entry])

    reloaded = {
        name: FinancialSituationMemory(name, make_config(persist_dir))
        for name in entries
    }
    for name, (situation, _recommendation) in entries.items():
        assert reloaded[name].documents == [situation]

    on_disk = {entry.name for entry in Path(persist_dir).iterdir()}
    assert "bull_memory.json" in on_disk
    assert "bear_memory.json" in on_disk
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# clear() persists the empty state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_clear_persists(persist_dir):
    """A new instance created after ``clear()`` starts empty instead of reloading old data."""
    first = FinancialSituationMemory("test", make_config(persist_dir))
    first.add_situations([("situation", "rec")])
    first.clear()

    second = FinancialSituationMemory("test", make_config(persist_dir))

    assert second.documents == []
    assert second.bm25 is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resilience: corrupt or mismatched files fall back to empty memory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_corrupt_json_falls_back_to_empty(persist_dir):
    """An unparseable persist file does not crash construction; memory starts empty."""
    target_dir = Path(persist_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    broken_file = target_dir / "test.json"
    broken_file.write_text("not valid json", encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
|
||||
|
||||
|
||||
def test_mismatched_lengths_falls_back_to_empty(persist_dir):
    """A persist file with two documents but one recommendation is ignored on load."""
    target_dir = Path(persist_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    payload = {"documents": ["a", "b"], "recommendations": ["r1"]}
    serialized = json.dumps(payload)
    (target_dir / "test.json").write_text(serialized, encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))

    assert memory.documents == []
    assert memory.bm25 is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File format: JSON is human-readable and well-formed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_file_is_valid_json(persist_dir):
    """The saved file parses as JSON and holds list-valued ``documents`` and ``recommendations``."""
    memory = FinancialSituationMemory("test", make_config(persist_dir))
    memory.add_situations([("situation", "rec")])

    saved = Path(persist_dir) / "test.json"
    assert saved.exists()

    data = json.loads(saved.read_text(encoding="utf-8"))

    expected_keys = ("documents", "recommendations")
    # Check that both keys exist first, then that both values are lists.
    for key in expected_keys:
        assert key in data
    for key in expected_keys:
        assert isinstance(data[key], list)
|
||||
|
|
@ -4,6 +4,12 @@ Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
|||
no token limits, works offline with any LLM provider.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from rank_bm25 import BM25Okapi
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
|
|
@ -17,12 +23,75 @@ class FinancialSituationMemory:
|
|||
|
||||
Args:
|
||||
name: Name identifier for this memory instance
|
||||
config: Configuration dict (kept for API compatibility, not used for BM25)
|
||||
config: Configuration dict. If config contains a non-empty
|
||||
``memory_persist_dir`` key, documents and recommendations
|
||||
are loaded from (and saved to) that directory.
|
||||
"""
|
||||
self.name = name
|
||||
self.documents: List[str] = []
|
||||
self.recommendations: List[str] = []
|
||||
self.bm25 = None
|
||||
self._persist_path = None
|
||||
|
||||
if config:
|
||||
persist_dir = config.get("memory_persist_dir")
|
||||
if persist_dir:
|
||||
self._persist_path = Path(persist_dir) / f"{name}.json"
|
||||
self._load()
|
||||
|
||||
def _load(self):
    """Load documents and recommendations from disk if the persist file exists.

    Does nothing when no persist path is configured or the file is absent.
    A file that cannot be read, decoded, or parsed, or whose content does
    not have the expected shape, is logged and ignored so the instance
    keeps its current (empty) state instead of raising.
    """
    if not (self._persist_path and self._persist_path.exists()):
        return

    try:
        with open(self._persist_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError (malformed JSON) and
        # UnicodeDecodeError (file is not valid UTF-8).
        logger.warning(
            "Could not load memory from %s (%s). Starting with empty memory.",
            self._persist_path,
            exc,
        )
        return

    # Valid JSON is not necessarily a JSON object: a top-level list, string
    # or number has no ``.get`` and would otherwise raise AttributeError.
    if not isinstance(data, dict):
        logger.warning(
            "Memory file %s is corrupt (top-level JSON value is not an object). "
            "Starting with empty memory.",
            self._persist_path,
        )
        return

    docs = data.get("documents", [])
    recs = data.get("recommendations", [])

    # Both fields must be lists before their lengths can be compared.
    if not (isinstance(docs, list) and isinstance(recs, list)):
        logger.warning(
            "Memory file %s is corrupt (documents/recommendations are not lists). "
            "Starting with empty memory.",
            self._persist_path,
        )
        return

    if len(docs) != len(recs):
        logger.warning(
            "Memory file %s is corrupt (documents/recommendations length mismatch). "
            "Starting with empty memory.",
            self._persist_path,
        )
        return

    self.documents = docs
    self.recommendations = recs
    if self.documents:
        self._rebuild_index()
|
||||
|
||||
def _save(self):
    """Persist documents and recommendations to disk atomically.

    Does nothing when no persist path is configured. The data is written
    to a temporary file in the target directory and then renamed over the
    persist file, so a reader never sees a partially written file.

    Saving is best-effort: any ``OSError`` (including failure to create
    the directory) is logged as a warning rather than raised. The
    temporary file is removed whenever the write or rename does not
    complete, so failed saves do not leave ``*.tmp`` files behind.
    """
    if not self._persist_path:
        return

    # Path of a temp file that still needs cleaning up; None when there is none.
    pending_tmp = None
    try:
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=self._persist_path.parent,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            pending_tmp = Path(tmp.name)
            json.dump(
                {
                    "documents": self.documents,
                    "recommendations": self.recommendations,
                },
                tmp,
                indent=2,
                ensure_ascii=False,
            )
        # Rename only after the ``with`` block has closed (and flushed) the file.
        pending_tmp.replace(self._persist_path)
        pending_tmp = None
    except OSError as exc:
        logger.warning("Could not save memory to %s (%s).", self._persist_path, exc)
    finally:
        if pending_tmp is not None:
            # The write or rename failed, so drop the orphaned temp file.
            # Cleanup is itself best-effort and must not mask the original error.
            try:
                pending_tmp.unlink()
            except OSError:
                pass
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Tokenize text for BM25 indexing.
|
||||
|
|
@ -53,6 +122,7 @@ class FinancialSituationMemory:
|
|||
|
||||
# Rebuild BM25 index with new documents
|
||||
self._rebuild_index()
|
||||
self._save()
|
||||
|
||||
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
||||
"""Find matching recommendations using BM25 similarity.
|
||||
|
|
@ -96,6 +166,7 @@ class FinancialSituationMemory:
|
|||
self.documents = []
|
||||
self.recommendations = []
|
||||
self.bm25 = None
|
||||
self._save()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ DEFAULT_CONFIG = {
|
|||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||
"memory_persist_dir": os.path.join(_TRADINGAGENTS_HOME, "memory"),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "gpt-5.4",
|
||||
|
|
|
|||
Loading…
Reference in New Issue