Merge 9ee18699fd into fa4d01c23a
This commit is contained in:
commit
fbf51db362
13
main.py
13
main.py
|
|
@ -20,6 +20,10 @@ config["data_vendors"] = {
|
|||
"news_data": "yfinance", # Options: alpha_vantage, yfinance
|
||||
}
|
||||
|
||||
# Enable memory persistence so lessons survive restarts (optional).
|
||||
# Set to None or omit to keep the default RAM-only behaviour.
|
||||
config["memory_persist_dir"] = "~/.tradingagents/memory"
|
||||
|
||||
# Initialize with custom config
|
||||
ta = TradingAgentsGraph(debug=True, config=config)
|
||||
|
||||
|
|
@ -27,5 +31,10 @@ ta = TradingAgentsGraph(debug=True, config=config)
|
|||
_, decision = ta.propagate("NVDA", "2024-05-10")
|
||||
print(decision)
|
||||
|
||||
# Memorize mistakes and reflect
|
||||
# ta.reflect_and_remember(1000) # parameter is the position returns
|
||||
# Reflect on the decision after observing actual returns.
|
||||
# Call this once the position closes and the P&L is known.
|
||||
# The signed float indicates outcome: positive = correct signal,
|
||||
# negative = incorrect signal. Lessons are persisted when
|
||||
# memory_persist_dir is set, so the next TradingAgentsGraph
|
||||
# instance will load them automatically.
|
||||
# ta.reflect_and_remember(0.03) # e.g. 3 % gain
|
||||
|
|
|
|||
|
|
@ -0,0 +1,292 @@
|
|||
"""Tests for FinancialSituationMemory JSON file persistence.
|
||||
|
||||
Covers:
|
||||
- RAM-only mode (no config / config without key) behaves identically to before
|
||||
- Persistence: add_situations → restart → memories survive
|
||||
- Persistence: clear → restart → file reflects empty state
|
||||
- Atomic write: .tmp → .json rename
|
||||
- Corrupt / missing file is handled gracefully
|
||||
- BM25 index is rebuilt after loading from disk
|
||||
- Multiple memory instances sharing the same dir don't collide
|
||||
- Default config includes memory_persist_dir key
|
||||
- TradingAgentsGraph passes config through to all five memory instances
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
|
||||
|
||||
SAMPLE_DATA = [
|
||||
(
|
||||
"High inflation rate with rising interest rates",
|
||||
"Consider defensive sectors like utilities.",
|
||||
),
|
||||
(
|
||||
"Tech sector volatility with institutional selling pressure",
|
||||
"Reduce high-growth tech exposure.",
|
||||
),
|
||||
(
|
||||
"Strong dollar affecting emerging markets",
|
||||
"Hedge currency exposure in international positions.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestRamOnlyMode(unittest.TestCase):
|
||||
"""Ensure RAM-only mode (no persistence) is fully backward-compatible."""
|
||||
|
||||
def test_no_config(self):
|
||||
mem = FinancialSituationMemory("test")
|
||||
self.assertIsNone(mem._persist_path)
|
||||
mem.add_situations(SAMPLE_DATA[:1])
|
||||
self.assertEqual(len(mem.documents), 1)
|
||||
|
||||
def test_config_without_persist_key(self):
|
||||
mem = FinancialSituationMemory("test", config={"some_other_key": True})
|
||||
self.assertIsNone(mem._persist_path)
|
||||
|
||||
def test_config_with_none_persist_dir(self):
|
||||
mem = FinancialSituationMemory("test", config={"memory_persist_dir": None})
|
||||
self.assertIsNone(mem._persist_path)
|
||||
|
||||
def test_config_with_empty_string_persist_dir(self):
|
||||
mem = FinancialSituationMemory("test", config={"memory_persist_dir": ""})
|
||||
self.assertIsNone(mem._persist_path)
|
||||
|
||||
def test_ram_only_data_lost_on_new_instance(self):
|
||||
mem1 = FinancialSituationMemory("test")
|
||||
mem1.add_situations(SAMPLE_DATA)
|
||||
self.assertEqual(len(mem1.documents), 3)
|
||||
|
||||
mem2 = FinancialSituationMemory("test")
|
||||
self.assertEqual(len(mem2.documents), 0)
|
||||
|
||||
|
||||
class TestPersistence(unittest.TestCase):
|
||||
"""Core persistence round-trip tests."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.config = {"memory_persist_dir": self.tmpdir}
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.tmpdir, ignore_errors=True)
|
||||
|
||||
def test_file_created_on_add(self):
|
||||
mem = FinancialSituationMemory("bull_memory", config=self.config)
|
||||
self.assertFalse(
|
||||
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
|
||||
"File should not exist before adding data",
|
||||
)
|
||||
mem.add_situations(SAMPLE_DATA[:1])
|
||||
self.assertTrue(
|
||||
(pathlib.Path(self.tmpdir) / "bull_memory.json").exists(),
|
||||
"File should be created after add_situations",
|
||||
)
|
||||
|
||||
def test_round_trip(self):
|
||||
"""Data survives a full destroy-and-recreate cycle."""
|
||||
mem1 = FinancialSituationMemory("rt", config=self.config)
|
||||
mem1.add_situations(SAMPLE_DATA)
|
||||
self.assertEqual(len(mem1.documents), 3)
|
||||
|
||||
# Destroy first instance, create a new one — simulates process restart
|
||||
del mem1
|
||||
mem2 = FinancialSituationMemory("rt", config=self.config)
|
||||
self.assertEqual(len(mem2.documents), 3)
|
||||
self.assertEqual(mem2.documents[0], SAMPLE_DATA[0][0])
|
||||
self.assertEqual(mem2.recommendations[1], SAMPLE_DATA[1][1])
|
||||
|
||||
def test_incremental_add_persists(self):
|
||||
"""Multiple add_situations calls accumulate correctly on disk."""
|
||||
mem = FinancialSituationMemory("inc", config=self.config)
|
||||
mem.add_situations(SAMPLE_DATA[:1])
|
||||
mem.add_situations(SAMPLE_DATA[1:2])
|
||||
|
||||
del mem
|
||||
mem2 = FinancialSituationMemory("inc", config=self.config)
|
||||
self.assertEqual(len(mem2.documents), 2)
|
||||
|
||||
def test_clear_persists_empty(self):
|
||||
mem = FinancialSituationMemory("clr", config=self.config)
|
||||
mem.add_situations(SAMPLE_DATA)
|
||||
mem.clear()
|
||||
|
||||
del mem
|
||||
mem2 = FinancialSituationMemory("clr", config=self.config)
|
||||
self.assertEqual(len(mem2.documents), 0)
|
||||
self.assertEqual(len(mem2.recommendations), 0)
|
||||
|
||||
def test_bm25_rebuilt_after_load(self):
|
||||
"""BM25 index should work after loading from disk."""
|
||||
mem1 = FinancialSituationMemory("bm25", config=self.config)
|
||||
mem1.add_situations(SAMPLE_DATA)
|
||||
|
||||
del mem1
|
||||
mem2 = FinancialSituationMemory("bm25", config=self.config)
|
||||
results = mem2.get_memories("rising interest rates inflation", n_matches=1)
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertIn("inflation", results[0]["matched_situation"].lower())
|
||||
|
||||
def test_json_file_content_valid(self):
|
||||
"""Persisted JSON has the expected schema."""
|
||||
mem = FinancialSituationMemory("schema", config=self.config)
|
||||
mem.add_situations(SAMPLE_DATA[:2])
|
||||
|
||||
fp = pathlib.Path(self.tmpdir) / "schema.json"
|
||||
data = json.loads(fp.read_text(encoding="utf-8"))
|
||||
self.assertIn("situations", data)
|
||||
self.assertIn("recommendations", data)
|
||||
self.assertEqual(len(data["situations"]), 2)
|
||||
self.assertEqual(len(data["recommendations"]), 2)
|
||||
|
||||
def test_unicode_round_trip(self):
|
||||
"""Non-ASCII data (e.g. CJK) survives persistence."""
|
||||
mem = FinancialSituationMemory("uni", config=self.config)
|
||||
mem.add_situations([("通胀上升,利率攀升", "考虑防御性板块")])
|
||||
|
||||
del mem
|
||||
mem2 = FinancialSituationMemory("uni", config=self.config)
|
||||
self.assertEqual(mem2.documents[0], "通胀上升,利率攀升")
|
||||
self.assertEqual(mem2.recommendations[0], "考虑防御性板块")
|
||||
|
||||
|
||||
class TestEdgeCases(unittest.TestCase):
|
||||
"""Graceful handling of corrupt / missing files."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.config = {"memory_persist_dir": self.tmpdir}
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.tmpdir, ignore_errors=True)
|
||||
|
||||
def test_corrupt_json_starts_fresh(self):
|
||||
fp = pathlib.Path(self.tmpdir) / "bad.json"
|
||||
fp.write_text("NOT VALID JSON {{{", encoding="utf-8")
|
||||
mem = FinancialSituationMemory("bad", config=self.config)
|
||||
self.assertEqual(len(mem.documents), 0, "Should start fresh on corrupt file")
|
||||
|
||||
def test_missing_keys_starts_fresh(self):
|
||||
fp = pathlib.Path(self.tmpdir) / "partial.json"
|
||||
fp.write_text(json.dumps({"unrelated": True}), encoding="utf-8")
|
||||
mem = FinancialSituationMemory("partial", config=self.config)
|
||||
self.assertEqual(len(mem.documents), 0)
|
||||
|
||||
def test_mismatched_lengths_truncates(self):
|
||||
"""If situations and recommendations have different lengths, zip truncates."""
|
||||
fp = pathlib.Path(self.tmpdir) / "mismatch.json"
|
||||
fp.write_text(json.dumps({
|
||||
"situations": ["s1", "s2", "s3"],
|
||||
"recommendations": ["r1", "r2"],
|
||||
}), encoding="utf-8")
|
||||
mem = FinancialSituationMemory("mismatch", config=self.config)
|
||||
self.assertEqual(len(mem.documents), 2)
|
||||
self.assertEqual(len(mem.recommendations), 2)
|
||||
|
||||
def test_parent_dirs_created(self):
|
||||
nested = os.path.join(self.tmpdir, "a", "b", "c")
|
||||
config = {"memory_persist_dir": nested}
|
||||
mem = FinancialSituationMemory("nested", config=config)
|
||||
mem.add_situations(SAMPLE_DATA[:1])
|
||||
self.assertTrue(pathlib.Path(nested).is_dir())
|
||||
self.assertTrue((pathlib.Path(nested) / "nested.json").exists())
|
||||
|
||||
def test_tilde_expansion(self):
|
||||
"""~ in persist dir is expanded."""
|
||||
config = {"memory_persist_dir": "~/.__test_tradingagents_mem__"}
|
||||
try:
|
||||
mem = FinancialSituationMemory("tilde", config=config)
|
||||
self.assertFalse(str(mem._persist_path).startswith("~"))
|
||||
self.assertIn(os.path.expanduser("~"), str(mem._persist_path))
|
||||
finally:
|
||||
import shutil
|
||||
shutil.rmtree(
|
||||
os.path.expanduser("~/.__test_tradingagents_mem__"),
|
||||
ignore_errors=True,
|
||||
)
|
||||
|
||||
|
||||
class TestMultipleInstances(unittest.TestCase):
|
||||
"""Multiple memory instances in the same directory don't collide."""
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.config = {"memory_persist_dir": self.tmpdir}
|
||||
|
||||
def tearDown(self):
|
||||
import shutil
|
||||
shutil.rmtree(self.tmpdir, ignore_errors=True)
|
||||
|
||||
def test_independent_files(self):
|
||||
bull = FinancialSituationMemory("bull_memory", config=self.config)
|
||||
bear = FinancialSituationMemory("bear_memory", config=self.config)
|
||||
|
||||
bull.add_situations(SAMPLE_DATA[:1])
|
||||
bear.add_situations(SAMPLE_DATA[1:2])
|
||||
|
||||
del bull, bear
|
||||
|
||||
bull2 = FinancialSituationMemory("bull_memory", config=self.config)
|
||||
bear2 = FinancialSituationMemory("bear_memory", config=self.config)
|
||||
|
||||
self.assertEqual(len(bull2.documents), 1)
|
||||
self.assertEqual(len(bear2.documents), 1)
|
||||
self.assertNotEqual(bull2.documents[0], bear2.documents[0])
|
||||
|
||||
|
||||
class TestDefaultConfig(unittest.TestCase):
|
||||
"""Verify the default config includes the new key."""
|
||||
|
||||
def test_key_exists(self):
|
||||
self.assertIn("memory_persist_dir", DEFAULT_CONFIG)
|
||||
|
||||
def test_default_is_none(self):
|
||||
self.assertIsNone(
|
||||
DEFAULT_CONFIG["memory_persist_dir"],
|
||||
"Default should be None (RAM-only) for backward compatibility",
|
||||
)
|
||||
|
||||
|
||||
class TestTradingGraphIntegration(unittest.TestCase):
|
||||
"""Verify TradingAgentsGraph passes config to all memory instances.
|
||||
|
||||
This is a source-code audit test — it does NOT instantiate the graph
|
||||
(which requires LLM API keys and heavy dependencies), but instead
|
||||
inspects the constructor to confirm the config dict is forwarded.
|
||||
"""
|
||||
|
||||
def test_config_forwarded_to_memory_constructors(self):
|
||||
"""All five FinancialSituationMemory() calls receive self.config."""
|
||||
import inspect
|
||||
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
||||
|
||||
source = inspect.getsource(TradingAgentsGraph.__init__)
|
||||
|
||||
memory_names = [
|
||||
"bull_memory",
|
||||
"bear_memory",
|
||||
"trader_memory",
|
||||
"invest_judge_memory",
|
||||
"portfolio_manager_memory",
|
||||
]
|
||||
|
||||
for name in memory_names:
|
||||
# Expect: FinancialSituationMemory("name", self.config)
|
||||
self.assertIn(
|
||||
f'FinancialSituationMemory("{name}", self.config)',
|
||||
source,
|
||||
f"TradingAgentsGraph.__init__ should pass self.config to {name}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -2,28 +2,100 @@
|
|||
|
||||
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
||||
no token limits, works offline with any LLM provider.
|
||||
|
||||
Supports optional JSON file persistence via config["memory_persist_dir"].
|
||||
When set, memories survive process restarts; when unset, behaves as
|
||||
before (RAM-only).
|
||||
"""
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
"""Memory system for storing and retrieving financial situations using BM25."""
|
||||
"""Memory system for storing and retrieving financial situations using BM25.
|
||||
|
||||
Persistence
|
||||
-----------
|
||||
Pass ``config={"memory_persist_dir": "/some/path"}`` to enable JSON
|
||||
file persistence. Each memory instance writes a
|
||||
``<memory_persist_dir>/<name>.json`` file that is loaded on
|
||||
construction and updated on every ``add_situations`` / ``clear``
|
||||
call. When ``memory_persist_dir`` is *not* set (the default),
|
||||
the class behaves identically to the original RAM-only version.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: dict = None):
|
||||
"""Initialize the memory system.
|
||||
|
||||
Args:
|
||||
name: Name identifier for this memory instance
|
||||
config: Configuration dict (kept for API compatibility, not used for BM25)
|
||||
config: Configuration dict. Recognises the key
|
||||
``memory_persist_dir`` — when set to a directory path,
|
||||
memories are persisted to ``<dir>/<name>.json``.
|
||||
"""
|
||||
self.name = name
|
||||
self.documents: List[str] = []
|
||||
self.recommendations: List[str] = []
|
||||
self.bm25 = None
|
||||
|
||||
# Resolve persistence path (may be None → RAM-only)
|
||||
self._persist_path: pathlib.Path | None = None
|
||||
if config and config.get("memory_persist_dir"):
|
||||
d = pathlib.Path(config["memory_persist_dir"]).expanduser()
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
self._persist_path = d / f"{name}.json"
|
||||
|
||||
# Load previously persisted memories (no-op when path is None)
|
||||
self._load()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Persistence helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _load(self):
|
||||
"""Load previously persisted memories from disk (no-op when RAM-only)."""
|
||||
if self._persist_path is None or not self._persist_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._persist_path.read_text(encoding="utf-8"))
|
||||
situations = data.get("situations", [])
|
||||
recommendations = data.get("recommendations", [])
|
||||
# Populate lists directly to avoid redundant _save() call
|
||||
for situation, recommendation in zip(situations, recommendations):
|
||||
self.documents.append(situation)
|
||||
self.recommendations.append(recommendation)
|
||||
if self.documents:
|
||||
self._rebuild_index()
|
||||
except (json.JSONDecodeError, OSError, AttributeError, TypeError):
|
||||
# Corrupt / unreadable file — start fresh
|
||||
pass
|
||||
|
||||
def _save(self):
|
||||
"""Persist current memories to disk (no-op when RAM-only)."""
|
||||
if self._persist_path is None:
|
||||
return
|
||||
payload = {
|
||||
"situations": list(self.documents),
|
||||
"recommendations": list(self.recommendations),
|
||||
}
|
||||
try:
|
||||
# Atomic-ish write: write to temp then rename
|
||||
tmp = self._persist_path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
tmp.replace(self._persist_path)
|
||||
except OSError:
|
||||
# Fail gracefully if disk is full or permissions are restricted
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tokenisation & indexing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Tokenize text for BM25 indexing.
|
||||
|
||||
|
|
@ -54,6 +126,9 @@ class FinancialSituationMemory:
|
|||
# Rebuild BM25 index with new documents
|
||||
self._rebuild_index()
|
||||
|
||||
# Persist to disk (no-op when RAM-only)
|
||||
self._save()
|
||||
|
||||
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
||||
"""Find matching recommendations using BM25 similarity.
|
||||
|
||||
|
|
@ -96,6 +171,7 @@ class FinancialSituationMemory:
|
|||
self.documents = []
|
||||
self.recommendations = []
|
||||
self.bm25 = None
|
||||
self._save() # Persist the empty state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ DEFAULT_CONFIG = {
|
|||
# Output language for analyst reports and final decision
|
||||
# Internal agent debate stays in English for reasoning quality
|
||||
"output_language": "English",
|
||||
# Memory persistence (None = RAM-only, path = JSON file persistence)
|
||||
"memory_persist_dir": None,
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue