feat: persist FinancialSituationMemory to disk (#563)
This commit is contained in:
parent
8536ccacdd
commit
bf7d27e0a9
|
|
@ -0,0 +1,166 @@
|
||||||
|
"""Tests for FinancialSituationMemory persistence (issue #563)."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def persist_dir(tmp_path):
    """Return the path of a ``memory`` sub-directory of ``tmp_path`` as a string.

    The fixture only builds the path; it does not create the directory.
    """
    memory_dir = tmp_path / "memory"
    return str(memory_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def make_config(persist_dir):
    """Build a config dict whose only key, ``memory_persist_dir``, is *persist_dir*."""
    config = dict(memory_persist_dir=persist_dir)
    return config
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Persistence: data survives a fresh instance
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_data_survives_restart(persist_dir):
    """A second instance with the same name and directory sees the saved entry."""
    writer = FinancialSituationMemory("test", make_config(persist_dir))
    entries = [("situation A", "recommendation A")]
    writer.add_situations(entries)

    # A fresh instance stands in for a process restart.
    reader = FinancialSituationMemory("test", make_config(persist_dir))
    assert reader.documents == ["situation A"]
    assert reader.recommendations == ["recommendation A"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_entries_survive_restart(persist_dir):
    """Every entry added in one call is present in a fresh instance."""
    pairs = [
        ("situation A", "rec A"),
        ("situation B", "rec B"),
    ]
    writer = FinancialSituationMemory("test", make_config(persist_dir))
    writer.add_situations(pairs)

    reader = FinancialSituationMemory("test", make_config(persist_dir))
    assert len(reader.documents) == 2
    assert len(reader.recommendations) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_bm25_index_rebuilt_on_load(persist_dir):
    """Retrieval works on an instance whose data came from disk."""
    writer = FinancialSituationMemory("test", make_config(persist_dir))
    writer.add_situations([("rising interest rates inflation", "reduce duration")])

    reader = FinancialSituationMemory("test", make_config(persist_dir))
    matches = reader.get_memories("inflation rate rising", n_matches=1)

    assert len(matches) == 1
    top_match = matches[0]
    assert top_match["recommendation"] == "reduce duration"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RAM-only mode: no persist_dir → no file written
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_no_persist_dir_no_file(tmp_path):
    """An empty config leaves the persist path unset and keeps data in RAM only."""
    memory = FinancialSituationMemory("test", config={})
    memory.add_situations([("situation", "rec")])

    # No persistence target was configured.
    assert memory._persist_path is None

    # The entry is still held in memory.
    assert memory.documents == ["situation"]
    assert memory.recommendations == ["rec"]

    # The temporary directory stayed empty.
    leftover = list(tmp_path.iterdir())
    assert leftover == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_none_config_no_file(tmp_path):
    """Omitting the config leaves the persist path unset and keeps data in RAM only."""
    memory = FinancialSituationMemory("test")
    memory.add_situations([("situation", "rec")])

    # No persistence target was configured.
    assert memory._persist_path is None

    # The entry is still held in memory.
    assert memory.documents == ["situation"]
    assert memory.recommendations == ["rec"]

    # The temporary directory stayed empty.
    leftover = list(tmp_path.iterdir())
    assert leftover == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Instance isolation: separate names → separate files
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_separate_names_separate_files(persist_dir):
    """Differently named instances keep separate state and separate files."""
    bull_writer = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    bear_writer = FinancialSituationMemory("bear_memory", make_config(persist_dir))

    bull_writer.add_situations([("bull situation", "buy")])
    bear_writer.add_situations([("bear situation", "sell")])

    # Reload both memories from the shared directory.
    bull_reader = FinancialSituationMemory("bull_memory", make_config(persist_dir))
    bear_reader = FinancialSituationMemory("bear_memory", make_config(persist_dir))

    assert bull_reader.documents == ["bull situation"]
    assert bear_reader.documents == ["bear situation"]

    # Each name maps to its own JSON file.
    written = set()
    for entry in Path(persist_dir).iterdir():
        written.add(entry.name)
    assert "bull_memory.json" in written
    assert "bear_memory.json" in written
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# clear() persists the empty state
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_clear_persists(persist_dir):
    """A fresh instance created after ``clear()`` starts empty."""
    original = FinancialSituationMemory("test", make_config(persist_dir))
    original.add_situations([("situation", "rec")])
    original.clear()

    reloaded = FinancialSituationMemory("test", make_config(persist_dir))
    assert reloaded.documents == []
    assert reloaded.bm25 is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Resilience: corrupt or mismatched files fall back to empty memory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_corrupt_json_falls_back_to_empty(persist_dir):
    """Unparseable file content yields an empty memory instead of an exception."""
    directory = Path(persist_dir)
    directory.mkdir(parents=True, exist_ok=True)
    broken_file = directory / "test.json"
    broken_file.write_text("not valid json", encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))
    assert memory.documents == []
    assert memory.bm25 is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_mismatched_lengths_falls_back_to_empty(persist_dir):
    """A file with two documents but one recommendation yields an empty memory."""
    directory = Path(persist_dir)
    directory.mkdir(parents=True, exist_ok=True)

    payload = json.dumps({"documents": ["a", "b"], "recommendations": ["r1"]})
    (directory / "test.json").write_text(payload, encoding="utf-8")

    memory = FinancialSituationMemory("test", make_config(persist_dir))
    assert memory.documents == []
    assert memory.bm25 is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File format: JSON is human-readable and well-formed
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_file_is_valid_json(persist_dir):
    """The saved file parses as JSON and holds list-valued ``documents`` and ``recommendations``."""
    memory = FinancialSituationMemory("test", make_config(persist_dir))
    memory.add_situations([("situation", "rec")])

    saved = Path(persist_dir) / "test.json"
    assert saved.exists()

    payload = json.loads(saved.read_text(encoding="utf-8"))
    expected_keys = ("documents", "recommendations")
    for key in expected_keys:
        assert key in payload
    for key in expected_keys:
        assert isinstance(payload[key], list)
|
||||||
|
|
@ -4,6 +4,12 @@ Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
||||||
no token limits, works offline with any LLM provider.
|
no token limits, works offline with any LLM provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
from rank_bm25 import BM25Okapi
|
from rank_bm25 import BM25Okapi
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import re
|
import re
|
||||||
|
|
@ -17,12 +23,75 @@ class FinancialSituationMemory:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Name identifier for this memory instance
|
name: Name identifier for this memory instance
|
||||||
config: Configuration dict (kept for API compatibility, not used for BM25)
|
config: Configuration dict. If config contains a non-empty
|
||||||
|
``memory_persist_dir`` key, documents and recommendations
|
||||||
|
are loaded from (and saved to) that directory.
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.documents: List[str] = []
|
self.documents: List[str] = []
|
||||||
self.recommendations: List[str] = []
|
self.recommendations: List[str] = []
|
||||||
self.bm25 = None
|
self.bm25 = None
|
||||||
|
self._persist_path = None
|
||||||
|
|
||||||
|
if config:
|
||||||
|
persist_dir = config.get("memory_persist_dir")
|
||||||
|
if persist_dir:
|
||||||
|
self._persist_path = Path(persist_dir) / f"{name}.json"
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def _load(self):
    """Load documents and recommendations from disk if the persist file exists.

    Does nothing when no persist path is configured or the file is absent.
    Any unreadable or malformed file (I/O error, invalid UTF-8, invalid
    JSON, a top-level value that is not an object, fields that are not
    lists, non-string documents, or mismatched list lengths) is logged as a
    warning and ignored, leaving the instance's current state untouched.
    The BM25 index is rebuilt only when at least one document was loaded.
    """
    if not (self._persist_path and self._persist_path.exists()):
        return

    try:
        with open(self._persist_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError,
        # so a file with invalid bytes falls back like any other corrupt file.
        logger.warning(
            "Could not load memory from %s (%s). Starting with empty memory.",
            self._persist_path,
            exc,
        )
        return

    # Well-formed JSON is not necessarily the expected shape: reject anything
    # that is not {"documents": [str, ...], "recommendations": [...]}.
    docs = data.get("documents", []) if isinstance(data, dict) else None
    recs = data.get("recommendations", []) if isinstance(data, dict) else None
    well_formed = (
        isinstance(docs, list)
        and isinstance(recs, list)
        and all(isinstance(doc, str) for doc in docs)
    )
    if not well_formed:
        logger.warning(
            "Memory file %s has an unexpected structure. Starting with empty memory.",
            self._persist_path,
        )
        return

    if len(docs) != len(recs):
        logger.warning(
            "Memory file %s is corrupt (documents/recommendations length mismatch). "
            "Starting with empty memory.",
            self._persist_path,
        )
        return

    self.documents = docs
    self.recommendations = recs
    if self.documents:
        self._rebuild_index()
|
||||||
|
|
||||||
|
def _save(self):
    """Persist documents and recommendations to disk atomically.

    Does nothing when no persist path is configured. The data is written as
    JSON to a temporary file in the target directory and then renamed over
    the persist file, so readers never observe a half-written file.

    Saving is best-effort: any ``OSError`` (including failure to create the
    directory) is logged as a warning and swallowed. Other exceptions, such
    as ``TypeError`` from non-serializable data, still propagate. In every
    failure case the temporary file is removed.
    """
    if not self._persist_path:
        return

    tmp_path = None  # set once the temp file exists; reset after a successful rename
    try:
        # Inside the try so an uncreatable directory is a warning, not a crash.
        self._persist_path.parent.mkdir(parents=True, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=self._persist_path.parent,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            tmp_path = Path(tmp.name)
            json.dump(
                {
                    "documents": self.documents,
                    "recommendations": self.recommendations,
                },
                tmp,
                indent=2,
                ensure_ascii=False,
            )
        tmp_path.replace(self._persist_path)
        tmp_path = None
    except OSError as exc:
        logger.warning("Could not save memory to %s (%s).", self._persist_path, exc)
    finally:
        # delete=False means a failed write or rename would otherwise leave
        # a stray *.tmp file behind on every unsuccessful save.
        if tmp_path is not None:
            try:
                tmp_path.unlink()
            except OSError:
                pass
|
||||||
|
|
||||||
def _tokenize(self, text: str) -> List[str]:
|
def _tokenize(self, text: str) -> List[str]:
|
||||||
"""Tokenize text for BM25 indexing.
|
"""Tokenize text for BM25 indexing.
|
||||||
|
|
@ -53,6 +122,7 @@ class FinancialSituationMemory:
|
||||||
|
|
||||||
# Rebuild BM25 index with new documents
|
# Rebuild BM25 index with new documents
|
||||||
self._rebuild_index()
|
self._rebuild_index()
|
||||||
|
self._save()
|
||||||
|
|
||||||
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
||||||
"""Find matching recommendations using BM25 similarity.
|
"""Find matching recommendations using BM25 similarity.
|
||||||
|
|
@ -96,6 +166,7 @@ class FinancialSituationMemory:
|
||||||
self.documents = []
|
self.documents = []
|
||||||
self.recommendations = []
|
self.recommendations = []
|
||||||
self.bm25 = None
|
self.bm25 = None
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ DEFAULT_CONFIG = {
|
||||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", os.path.join(_TRADINGAGENTS_HOME, "logs")),
|
||||||
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
"data_cache_dir": os.getenv("TRADINGAGENTS_CACHE_DIR", os.path.join(_TRADINGAGENTS_HOME, "cache")),
|
||||||
|
"memory_persist_dir": os.path.join(_TRADINGAGENTS_HOME, "memory"),
|
||||||
# LLM settings
|
# LLM settings
|
||||||
"llm_provider": "openai",
|
"llm_provider": "openai",
|
||||||
"deep_think_llm": "gpt-5.4",
|
"deep_think_llm": "gpt-5.4",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue