This commit is contained in:
d 🔹 2026-04-17 15:08:54 -04:00 committed by GitHub
commit fbf51db362
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 383 additions and 4 deletions

13
main.py
View File

@ -20,6 +20,10 @@ config["data_vendors"] = {
"news_data": "yfinance", # Options: alpha_vantage, yfinance
}
# Enable memory persistence so lessons survive restarts (optional).
# Set to None or omit to keep the default RAM-only behaviour.
config["memory_persist_dir"] = "~/.tradingagents/memory"
# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)
@ -27,5 +31,10 @@ ta = TradingAgentsGraph(debug=True, config=config)
_, decision = ta.propagate("NVDA", "2024-05-10")
print(decision)
# Memorize mistakes and reflect
# ta.reflect_and_remember(1000) # parameter is the position returns
# Reflect on the decision after observing actual returns.
# Call this once the position closes and the P&L is known.
# The signed float indicates outcome: positive = correct signal,
# negative = incorrect signal. Lessons are persisted when
# memory_persist_dir is set, so the next TradingAgentsGraph
# instance will load them automatically.
# ta.reflect_and_remember(0.03) # e.g. 3 % gain

View File

@ -0,0 +1,292 @@
"""Tests for FinancialSituationMemory JSON file persistence.
Covers:
- RAM-only mode (no config / config without key) behaves identically to before
- Persistence: add_situations restart memories survive
- Persistence: clear restart file reflects empty state
- Atomic write: .tmp .json rename
- Corrupt / missing file is handled gracefully
- BM25 index is rebuilt after loading from disk
- Multiple memory instances sharing the same dir don't collide
- Default config includes memory_persist_dir key
- TradingAgentsGraph passes config through to all five memory instances
"""
import json
import os
import pathlib
import tempfile
import unittest
from tradingagents.agents.utils.memory import FinancialSituationMemory
from tradingagents.default_config import DEFAULT_CONFIG
SAMPLE_DATA = [
(
"High inflation rate with rising interest rates",
"Consider defensive sectors like utilities.",
),
(
"Tech sector volatility with institutional selling pressure",
"Reduce high-growth tech exposure.",
),
(
"Strong dollar affecting emerging markets",
"Hedge currency exposure in international positions.",
),
]
class TestRamOnlyMode(unittest.TestCase):
"""Ensure RAM-only mode (no persistence) is fully backward-compatible."""
def test_no_config(self):
mem = FinancialSituationMemory("test")
self.assertIsNone(mem._persist_path)
mem.add_situations(SAMPLE_DATA[:1])
self.assertEqual(len(mem.documents), 1)
def test_config_without_persist_key(self):
mem = FinancialSituationMemory("test", config={"some_other_key": True})
self.assertIsNone(mem._persist_path)
def test_config_with_none_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": None})
self.assertIsNone(mem._persist_path)
def test_config_with_empty_string_persist_dir(self):
mem = FinancialSituationMemory("test", config={"memory_persist_dir": ""})
self.assertIsNone(mem._persist_path)
def test_ram_only_data_lost_on_new_instance(self):
mem1 = FinancialSituationMemory("test")
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
mem2 = FinancialSituationMemory("test")
self.assertEqual(len(mem2.documents), 0)
class TestPersistence(unittest.TestCase):
"""Core persistence round-trip tests."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_file_created_on_add(self):
mem = FinancialSituationMemory("bull_memory", config=self.config)
self.assertFalse(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should not exist before adding data",
)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
"File should be created after add_situations",
)
def test_round_trip(self):
"""Data survives a full destroy-and-recreate cycle."""
mem1 = FinancialSituationMemory("rt", config=self.config)
mem1.add_situations(SAMPLE_DATA)
self.assertEqual(len(mem1.documents), 3)
# Destroy first instance, create a new one — simulates process restart
del mem1
mem2 = FinancialSituationMemory("rt", config=self.config)
self.assertEqual(len(mem2.documents), 3)
self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0])
self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1])
def test_incremental_add_persists(self):
"""Multiple add_situations calls accumulate correctly on disk."""
mem = FinancialSituationMemory("inc", config=self.config)
mem.add_situations(SAMPLE_DATA[:1])
mem.add_situations(SAMPLE_DATA[1:2])
del mem
mem2 = FinancialSituationMemory("inc", config=self.config)
self.assertEqual(len(mem2.documents), 2)
def test_clear_persists_empty(self):
mem = FinancialSituationMemory("clr", config=self.config)
mem.add_situations(SAMPLE_DATA)
mem.clear()
del mem
mem2 = FinancialSituationMemory("clr", config=self.config)
self.assertEqual(len(mem2.documents), 0)
self.assertEqual(len(mem2.recommendations), 0)
def test_bm25_rebuilt_after_load(self):
"""BM25 index should work after loading from disk."""
mem1 = FinancialSituationMemory("bm25", config=self.config)
mem1.add_situations(SAMPLE_DATA)
del mem1
mem2 = FinancialSituationMemory("bm25", config=self.config)
results = mem2.get_memories("rising interest rates inflation", n_matches=1)
self.assertEqual(len(results), 1)
self.assertIn("inflation", results[0]["matched_situation"].lower())
def test_json_file_content_valid(self):
"""Persisted JSON has the expected schema."""
mem = FinancialSituationMemory("schema", config=self.config)
mem.add_situations(SAMPLE_DATA[:2])
fp = pathlib.Path(self.tmpdir) / "schema.json"
data = json.loads(fp.read_text(encoding="utf-8"))
self.assertIn("situations", data)
self.assertIn("recommendations", data)
self.assertEqual(len(data["situations"]), 2)
self.assertEqual(len(data["recommendations"]), 2)
def test_unicode_round_trip(self):
"""Non-ASCII data (e.g. CJK) survives persistence."""
mem = FinancialSituationMemory("uni", config=self.config)
mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")])
del mem
mem2 = FinancialSituationMemory("uni", config=self.config)
self.assertEqual(mem2.documents[0], "通胀上升,利率攀升")
self.assertEqual(mem2.recommendations[0], "考虑防御性板块")
class TestEdgeCases(unittest.TestCase):
"""Graceful handling of corrupt / missing files."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_corrupt_json_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "bad.json"
fp.write_text("NOT VALID JSON {{{", encoding="utf-8")
mem = FinancialSituationMemory("bad", config=self.config)
self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file")
def test_missing_keys_starts_fresh(self):
fp = pathlib.Path(self.tmpdir) / "partial.json"
fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8")
mem = FinancialSituationMemory("partial", config=self.config)
self.assertEqual(len(mem.documents), 0)
def test_mismatched_lengths_truncates(self):
"""If situations and recommendations have different lengths, zip truncates."""
fp = pathlib.Path(self.tmpdir) / "mismatch.json"
fp.write_text(json.dumps({
"situations": ["s1", "s2", "s3"],
"recommendations": ["r1", "r2"],
}), encoding="utf-8")
mem = FinancialSituationMemory("mismatch", config=self.config)
self.assertEqual(len(mem.documents), 2)
self.assertEqual(len(mem.recommendations), 2)
def test_parent_dirs_created(self):
nested = os.path.join(self.tmpdir, "a", "b", "c")
config = {"memory_persist_dir": nested}
mem = FinancialSituationMemory("nested", config=config)
mem.add_situations(SAMPLE_DATA[:1])
self.assertTrue(pathlib.Path(nested).is_dir())
self.assertTrue((pathlib.Path(nested) / "nested.json").exists())
def test_tilde_expansion(self):
"""~ in persist dir is expanded."""
config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"}
try:
mem = FinancialSituationMemory("tilde", config=config)
self.assertFalse(str(mem._persist_path).startswith("~"))
self.assertIn(os.path.expanduser("~"), str(mem._persist_path))
finally:
import shutil
shutil.rmtree(
os.path.expanduser("~/.__test_tradingagents_mem__"),
ignore_errors=True,
)
class TestMultipleInstances(unittest.TestCase):
"""Multiple memory instances in the same directory don't collide."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.config = {"memory_persist_dir": self.tmpdir}
def tearDown(self):
import shutil
shutil.rmtree(self.tmpdir, ignore_errors=True)
def test_independent_files(self):
bull = FinancialSituationMemory("bull_memory", config=self.config)
bear = FinancialSituationMemory("bear_memory", config=self.config)
bull.add_situations(SAMPLE_DATA[:1])
bear.add_situations(SAMPLE_DATA[1:2])
del bull, bear
bull2 = FinancialSituationMemory("bull_memory", config=self.config)
bear2 = FinancialSituationMemory("bear_memory", config=self.config)
self.assertEqual(len(bull2.documents), 1)
self.assertEqual(len(bear2.documents), 1)
self.assertNotEqual(bull2.documents[0], bear2.documents[0])
class TestDefaultConfig(unittest.TestCase):
"""Verify the default config includes the new key."""
def test_key_exists(self):
self.assertIn("memory_persist_dir", DEFAULT_CONFIG)
def test_default_is_none(self):
self.assertIsNone(
DEFAULT_CONFIG["memory_persist_dir"],
"Default should be None (RAM-only) for backward compatibility",
)
class TestTradingGraphIntegration(unittest.TestCase):
"""Verify TradingAgentsGraph passes config to all memory instances.
This is a source-code audit test it does NOT instantiate the graph
(which requires LLM API keys and heavy dependencies), but instead
inspects the constructor to confirm the config dict is forwarded.
"""
def test_config_forwarded_to_memory_constructors(self):
"""All five FinancialSituationMemory() calls receive self.config."""
import inspect
from tradingagents.graph.trading_graph import TradingAgentsGraph
source = inspect.getsource(TradingAgentsGraph.__init__)
memory_names = [
"bull_memory",
"bear_memory",
"trader_memory",
"invest_judge_memory",
"portfolio_manager_memory",
]
for name in memory_names:
# Expect: FinancialSituationMemory("name", self.config)
self.assertIn(
f'FinancialSituationMemory("{name}", self.config)',
source,
f"TradingAgentsGraph.__init__ should pass self.config to {name}",
)
if __name__ == "__main__":
unittest.main()

View File

@ -2,28 +2,100 @@
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
Supports optional JSON file persistence via config["memory_persist_dir"].
When set, memories survive process restarts; when unset, behaves as
before (RAM-only).
"""
import json
import pathlib
from rank_bm25 import BM25Okapi
from typing import List, Tuple
import re
class FinancialSituationMemory:
"""Memory system for storing and retrieving financial situations using BM25."""
"""Memory system for storing and retrieving financial situations using BM25.
Persistence
-----------
Pass ``config={"memory_persist_dir": "/some/path"}`` to enable JSON
file persistence. Each memory instance writes a
``<memory_persist_dir>/<name>.json`` file that is loaded on
construction and updated on every ``add_situations`` / ``clear``
call. When ``memory_persist_dir`` is *not* set (the default),
the class behaves identically to the original RAM-only version.
"""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
config: Configuration dict. Recognises the key
``memory_persist_dir`` when set to a directory path,
memories are persisted to ``<dir>/<name>.json``.
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
# Resolve persistence path (may be None → RAM-only)
self._persist_path: pathlib.Path | None = None
if config and config.get("memory_persist_dir"):
d = pathlib.Path(config["memory_persist_dir"]).expanduser()
d.mkdir(parents=True, exist_ok=True)
self._persist_path = d / f"{name}.json"
# Load previously persisted memories (no-op when path is None)
self._load()
# ------------------------------------------------------------------
# Persistence helpers
# ------------------------------------------------------------------
def _load(self):
"""Load previously persisted memories from disk (no-op when RAM-only)."""
if self._persist_path is None or not self._persist_path.exists():
return
try:
data = json.loads(self._persist_path.read_text(encoding="utf-8"))
situations = data.get("situations", [])
recommendations = data.get("recommendations", [])
# Populate lists directly to avoid redundant _save() call
for situation, recommendation in zip(situations, recommendations):
self.documents.append(situation)
self.recommendations.append(recommendation)
if self.documents:
self._rebuild_index()
except (json.JSONDecodeError, OSError, AttributeError, TypeError):
# Corrupt / unreadable file — start fresh
pass
def _save(self):
"""Persist current memories to disk (no-op when RAM-only)."""
if self._persist_path is None:
return
payload = {
"situations": list(self.documents),
"recommendations": list(self.recommendations),
}
try:
# Atomic-ish write: write to temp then rename
tmp = self._persist_path.with_suffix(".tmp")
tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
tmp.replace(self._persist_path)
except OSError:
# Fail gracefully if disk is full or permissions are restricted
pass
# ------------------------------------------------------------------
# Tokenisation & indexing
# ------------------------------------------------------------------
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
@ -54,6 +126,9 @@ class FinancialSituationMemory:
# Rebuild BM25 index with new documents
self._rebuild_index()
# Persist to disk (no-op when RAM-only)
self._save()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
@ -96,6 +171,7 @@ class FinancialSituationMemory:
self.documents = []
self.recommendations = []
self.bm25 = None
self._save() # Persist the empty state
if __name__ == "__main__":

View File

@ -18,6 +18,8 @@ DEFAULT_CONFIG = {
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
# Memory persistence (None = RAM-only, path = JSON file persistence)
"memory_persist_dir": None,
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,