TradingAgents/tradingagents/memory/macro_memory.py

282 lines
10 KiB
Python

"""Macro memory — learn from past regime-level market context.
Stores macro regime states (VIX level, risk-on/off call, sector thesis, key
themes) and later associates outcomes, enabling agents to *reflect* on
regime accuracy and adjust forward-looking bias accordingly.
Unlike ReflexionMemory (which is per-ticker), MacroMemory operates at the
market-wide level. Each record captures the macro environment on a given date,
independent of any single security.
Backed by MongoDB when available; falls back to a local JSON file when not.
Schema (``macro_memory`` collection)::
{
"regime_date": str, # ISO date "2026-03-26"
"vix_level": float, # e.g. 25.3
"macro_call": str, # "risk-on" | "risk-off" | "neutral" | "transition"
"sector_thesis": str, # free-form regime summary
"key_themes": list, # list of top macro theme strings
"run_id": str | None,
"outcome": dict | None, # filled later by record_outcome()
"created_at": datetime,
}
Usage::
from tradingagents.memory.macro_memory import MacroMemory
mem = MacroMemory("mongodb://localhost:27017")
mem.record_macro_state(
date="2026-03-26",
vix_level=25.3,
macro_call="risk-off",
sector_thesis="Energy under pressure, Fed hawkish",
key_themes=["rate hikes", "oil volatility"],
)
context = mem.build_macro_context(limit=3)
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
_COLLECTION = "macro_memory"
_VALID_MACRO_CALLS = {"risk-on", "risk-off", "neutral", "transition"}
class MacroMemory:
"""MongoDB-backed macro regime memory.
Falls back to a local JSON file when MongoDB is unavailable, so the
feature always works (though with degraded query performance on the
local variant).
"""
def __init__(
self,
mongo_uri: str | None = None,
db_name: str = "tradingagents",
fallback_path: str | Path = "reports/macro_memory.json",
) -> None:
self._col = None
self._fallback_path = Path(fallback_path)
if mongo_uri:
try:
from pymongo import DESCENDING, MongoClient
client = MongoClient(mongo_uri)
db = client[db_name]
self._col = db[_COLLECTION]
self._col.create_index([("regime_date", DESCENDING)])
self._col.create_index("created_at")
logger.info("MacroMemory using MongoDB (db=%s)", db_name)
except Exception:
logger.warning(
"MacroMemory: MongoDB unavailable — using local file",
exc_info=True,
)
# ------------------------------------------------------------------
# Record macro state
# ------------------------------------------------------------------
def record_macro_state(
self,
date: str,
vix_level: float,
macro_call: str,
sector_thesis: str,
key_themes: list[str],
run_id: str | None = None,
) -> None:
"""Store a macro regime state for later reflection.
Args:
date: ISO date string, e.g. "2026-03-26".
vix_level: VIX index level at the time of the call.
macro_call: Regime classification: "risk-on", "risk-off",
"neutral", or "transition".
sector_thesis: Free-form summary of the prevailing sector view.
key_themes: Top macro themes driving the regime call.
run_id: Optional run identifier for traceability.
"""
normalized_call = macro_call.lower()
if normalized_call not in _VALID_MACRO_CALLS:
logger.warning(
"MacroMemory: unexpected macro_call %r (expected one of %s)",
macro_call,
_VALID_MACRO_CALLS,
)
doc: dict[str, Any] = {
"regime_date": date,
"vix_level": float(vix_level),
"macro_call": normalized_call,
"sector_thesis": sector_thesis,
"key_themes": list(key_themes),
"run_id": run_id,
"outcome": None,
"created_at": datetime.now(timezone.utc),
}
if self._col is not None:
self._col.insert_one(doc)
else:
# Local JSON fallback uses ISO string (JSON has no datetime type)
doc["created_at"] = doc["created_at"].isoformat()
self._append_local(doc)
# ------------------------------------------------------------------
# Record outcome (feedback loop)
# ------------------------------------------------------------------
def record_outcome(self, date: str, outcome: dict[str, Any]) -> bool:
"""Attach outcome to the most recent macro state for a given date.
Args:
date: ISO date string matching the original ``regime_date``.
outcome: Dict with evaluation data, e.g.::
{
"evaluation_date": "2026-04-26",
"vix_at_evaluation": 18.2,
"regime_confirmed": True,
"notes": "Risk-off call was correct; market sold off",
}
Returns:
True if a matching state was found and updated.
"""
if self._col is not None:
from pymongo import DESCENDING
doc = self._col.find_one_and_update(
{"regime_date": date, "outcome": None},
{"$set": {"outcome": outcome}},
sort=[("created_at", DESCENDING)],
)
return doc is not None
else:
return self._update_local_outcome(date, outcome)
# ------------------------------------------------------------------
# Query
# ------------------------------------------------------------------
def get_recent(self, limit: int = 3) -> list[dict[str, Any]]:
"""Return most recent macro states, newest first.
Args:
limit: Maximum number of results.
"""
if self._col is not None:
from pymongo import DESCENDING
cursor = self._col.find(
{},
{"_id": 0},
).sort("regime_date", DESCENDING).limit(limit)
return list(cursor)
else:
return self._load_recent_local(limit)
def build_macro_context(self, limit: int = 3) -> str:
"""Build a human-readable context string from recent macro states.
Suitable for injection into agent prompts. Returns a multi-line string
summarising recent regime calls and outcomes.
Format example::
- [2026-03-20] risk-off (VIX: 25.3)
Thesis: Energy sector under pressure, Fed hawkish
Themes: ['rate hikes', 'oil volatility']
Outcome: pending
Args:
limit: How many past states to include.
Returns:
Multi-line string summarising recent macro regime states.
"""
recent = self.get_recent(limit=limit)
if not recent:
return "No prior macro regime states recorded."
lines: list[str] = []
for rec in recent:
dt = rec.get("regime_date", "?")
call = rec.get("macro_call", "?")
vix = rec.get("vix_level", "?")
thesis = rec.get("sector_thesis", "")[:300]
themes = rec.get("key_themes", [])
outcome = rec.get("outcome")
if outcome:
confirmed = outcome.get("regime_confirmed", "?")
notes = outcome.get("notes", "")
outcome_str = f" Outcome: confirmed={confirmed}{notes}" if notes else f" Outcome: confirmed={confirmed}"
else:
outcome_str = " Outcome: pending"
lines.append(
f"- [{dt}] {call} (VIX: {vix})\n"
f" Thesis: {thesis}\n"
f" Themes: {themes}\n"
f"{outcome_str}"
)
return "\n".join(lines)
# ------------------------------------------------------------------
# Local JSON fallback
# ------------------------------------------------------------------
def _load_all_local(self) -> list[dict[str, Any]]:
"""Load all records from the local JSON file."""
if not self._fallback_path.exists():
return []
try:
return json.loads(self._fallback_path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
return []
def _save_all_local(self, records: list[dict[str, Any]]) -> None:
"""Overwrite the local JSON file with all records."""
self._fallback_path.parent.mkdir(parents=True, exist_ok=True)
self._fallback_path.write_text(
json.dumps(records, indent=2), encoding="utf-8"
)
def _append_local(self, doc: dict[str, Any]) -> None:
"""Append a single record to the local file."""
records = self._load_all_local()
records.append(doc)
self._save_all_local(records)
def _load_recent_local(self, limit: int) -> list[dict[str, Any]]:
"""Load and sort all records by regime_date descending from the local file."""
records = self._load_all_local()
records.sort(key=lambda r: r.get("regime_date", ""), reverse=True)
return records[:limit]
def _update_local_outcome(self, date: str, outcome: dict[str, Any]) -> bool:
"""Update the most recent matching macro state in the local file."""
records = self._load_all_local()
# Iterate newest first (reversed insertion order is a proxy)
for rec in reversed(records):
if rec.get("regime_date") == date and rec.get("outcome") is None:
rec["outcome"] = outcome
self._save_all_local(records)
return True
return False