221 lines
8.3 KiB
Python
221 lines
8.3 KiB
Python
"""Financial situation memory using BM25 for lexical similarity matching.
|
|
|
|
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
|
no token limits, works offline with any LLM provider.
|
|
|
|
Supports optional JSON file persistence via config["memory_persist_dir"].
|
|
When set, memories survive process restarts; when unset, behaves as
|
|
before (RAM-only).
|
|
"""
|
|
|
|
import json
|
|
import pathlib
|
|
|
|
from rank_bm25 import BM25Okapi
|
|
from typing import List, Tuple
|
|
import re
|
|
|
|
|
|
class FinancialSituationMemory:
|
|
"""Memory system for storing and retrieving financial situations using BM25.
|
|
|
|
Persistence
|
|
-----------
|
|
Pass ``config={"memory_persist_dir": "/some/path"}`` to enable JSON
|
|
file persistence. Each memory instance writes a
|
|
``<memory_persist_dir>/<name>.json`` file that is loaded on
|
|
construction and updated on every ``add_situations`` / ``clear``
|
|
call. When ``memory_persist_dir`` is *not* set (the default),
|
|
the class behaves identically to the original RAM-only version.
|
|
"""
|
|
|
|
def __init__(self, name: str, config: dict = None):
|
|
"""Initialize the memory system.
|
|
|
|
Args:
|
|
name: Name identifier for this memory instance
|
|
config: Configuration dict. Recognises the key
|
|
``memory_persist_dir`` — when set to a directory path,
|
|
memories are persisted to ``<dir>/<name>.json``.
|
|
"""
|
|
self.name = name
|
|
self.documents: List[str] = []
|
|
self.recommendations: List[str] = []
|
|
self.bm25 = None
|
|
|
|
# Resolve persistence path (may be None → RAM-only)
|
|
self._persist_path: pathlib.Path | None = None
|
|
if config and config.get("memory_persist_dir"):
|
|
d = pathlib.Path(config["memory_persist_dir"]).expanduser()
|
|
d.mkdir(parents=True, exist_ok=True)
|
|
self._persist_path = d / f"{name}.json"
|
|
|
|
# Load previously persisted memories (no-op when path is None)
|
|
self._load()
|
|
|
|
# ------------------------------------------------------------------
|
|
# Persistence helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
def _load(self):
|
|
"""Load previously persisted memories from disk (no-op when RAM-only)."""
|
|
if self._persist_path is None or not self._persist_path.exists():
|
|
return
|
|
try:
|
|
data = json.loads(self._persist_path.read_text(encoding="utf-8"))
|
|
situations = data.get("situations", [])
|
|
recommendations = data.get("recommendations", [])
|
|
# Populate lists directly to avoid redundant _save() call
|
|
for situation, recommendation in zip(situations, recommendations):
|
|
self.documents.append(situation)
|
|
self.recommendations.append(recommendation)
|
|
if self.documents:
|
|
self._rebuild_index()
|
|
except (json.JSONDecodeError, OSError, AttributeError, TypeError):
|
|
# Corrupt / unreadable file — start fresh
|
|
pass
|
|
|
|
def _save(self):
|
|
"""Persist current memories to disk (no-op when RAM-only)."""
|
|
if self._persist_path is None:
|
|
return
|
|
payload = {
|
|
"situations": list(self.documents),
|
|
"recommendations": list(self.recommendations),
|
|
}
|
|
try:
|
|
# Atomic-ish write: write to temp then rename
|
|
tmp = self._persist_path.with_suffix(".tmp")
|
|
tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
tmp.replace(self._persist_path)
|
|
except OSError:
|
|
# Fail gracefully if disk is full or permissions are restricted
|
|
pass
|
|
|
|
# ------------------------------------------------------------------
|
|
# Tokenisation & indexing
|
|
# ------------------------------------------------------------------
|
|
|
|
def _tokenize(self, text: str) -> List[str]:
|
|
"""Tokenize text for BM25 indexing.
|
|
|
|
Simple whitespace + punctuation tokenization with lowercasing.
|
|
"""
|
|
# Lowercase and split on non-alphanumeric characters
|
|
tokens = re.findall(r'\b\w+\b', text.lower())
|
|
return tokens
|
|
|
|
def _rebuild_index(self):
|
|
"""Rebuild the BM25 index after adding documents."""
|
|
if self.documents:
|
|
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
|
|
self.bm25 = BM25Okapi(tokenized_docs)
|
|
else:
|
|
self.bm25 = None
|
|
|
|
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
|
|
"""Add financial situations and their corresponding advice.
|
|
|
|
Args:
|
|
situations_and_advice: List of tuples (situation, recommendation)
|
|
"""
|
|
for situation, recommendation in situations_and_advice:
|
|
self.documents.append(situation)
|
|
self.recommendations.append(recommendation)
|
|
|
|
# Rebuild BM25 index with new documents
|
|
self._rebuild_index()
|
|
|
|
# Persist to disk (no-op when RAM-only)
|
|
self._save()
|
|
|
|
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
|
"""Find matching recommendations using BM25 similarity.
|
|
|
|
Args:
|
|
current_situation: The current financial situation to match against
|
|
n_matches: Number of top matches to return
|
|
|
|
Returns:
|
|
List of dicts with matched_situation, recommendation, and similarity_score
|
|
"""
|
|
if not self.documents or self.bm25 is None:
|
|
return []
|
|
|
|
# Tokenize query
|
|
query_tokens = self._tokenize(current_situation)
|
|
|
|
# Get BM25 scores for all documents
|
|
scores = self.bm25.get_scores(query_tokens)
|
|
|
|
# Get top-n indices sorted by score (descending)
|
|
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
|
|
|
|
# Build results
|
|
results = []
|
|
max_score = float(scores.max()) if len(scores) > 0 and scores.max() > 0 else 1.0
|
|
|
|
for idx in top_indices:
|
|
# Normalize score to 0-1 range for consistency
|
|
normalized_score = scores[idx] / max_score if max_score > 0 else 0
|
|
results.append({
|
|
"matched_situation": self.documents[idx],
|
|
"recommendation": self.recommendations[idx],
|
|
"similarity_score": normalized_score,
|
|
})
|
|
|
|
return results
|
|
|
|
def clear(self):
|
|
"""Clear all stored memories."""
|
|
self.documents = []
|
|
self.recommendations = []
|
|
self.bm25 = None
|
|
self._save() # Persist the empty state
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Example usage
|
|
matcher = FinancialSituationMemory("test_memory")
|
|
|
|
# Example data
|
|
example_data = [
|
|
(
|
|
"High inflation rate with rising interest rates and declining consumer spending",
|
|
"Consider defensive sectors like consumer staples and utilities. Review fixed-income portfolio duration.",
|
|
),
|
|
(
|
|
"Tech sector showing high volatility with increasing institutional selling pressure",
|
|
"Reduce exposure to high-growth tech stocks. Look for value opportunities in established tech companies with strong cash flows.",
|
|
),
|
|
(
|
|
"Strong dollar affecting emerging markets with increasing forex volatility",
|
|
"Hedge currency exposure in international positions. Consider reducing allocation to emerging market debt.",
|
|
),
|
|
(
|
|
"Market showing signs of sector rotation with rising yields",
|
|
"Rebalance portfolio to maintain target allocations. Consider increasing exposure to sectors benefiting from higher rates.",
|
|
),
|
|
]
|
|
|
|
# Add the example situations and recommendations
|
|
matcher.add_situations(example_data)
|
|
|
|
# Example query
|
|
current_situation = """
|
|
Market showing increased volatility in tech sector, with institutional investors
|
|
reducing positions and rising interest rates affecting growth stock valuations
|
|
"""
|
|
|
|
try:
|
|
recommendations = matcher.get_memories(current_situation, n_matches=2)
|
|
|
|
for i, rec in enumerate(recommendations, 1):
|
|
print(f"\nMatch {i}:")
|
|
print(f"Similarity Score: {rec['similarity_score']:.2f}")
|
|
print(f"Matched Situation: {rec['matched_situation']}")
|
|
print(f"Recommendation: {rec['recommendation']}")
|
|
|
|
except Exception as e:
|
|
print(f"Error during recommendation: {str(e)}")
|