TradingAgents/tradingagents/agents/utils/memory.py

273 lines
10 KiB
Python

"""Append-only markdown decision log for TradingAgents."""
from typing import List, Optional
from pathlib import Path
import re
class TradingMemoryLog:
    """Append-only markdown log of trading decisions and reflections.

    Each entry is stored as::

        [<date> | <ticker> | <rating> | pending]

        DECISION:
        <decision text>

    followed by ``_SEPARATOR``. Once an outcome is known the tag becomes
    ``[<date> | <ticker> | <rating> | <raw> | <alpha> | <N>d]`` and a
    ``REFLECTION:`` section is appended to the entry body.

    When no ``memory_log_path`` is configured every method is a no-op
    (writers return ``None``, readers return empty results).
    """

    RATINGS = {"buy", "overweight", "hold", "underweight", "sell"}
    # HTML comment used as a hard delimiter between entries; chosen because
    # it is not expected to occur in LLM prose output.
    _SEPARATOR = "\n\n<!-- ENTRY_END -->\n\n"
    # Literal tail of the tag line of an entry that has no outcome yet.
    _PENDING_SUFFIX = "| pending]"
    # Patterns are compiled once at class-definition time.
    _DECISION_RE = re.compile(r"DECISION:\n(.*?)(?=\nREFLECTION:|\Z)", re.DOTALL)
    _REFLECTION_RE = re.compile(r"REFLECTION:\n(.*?)$", re.DOTALL)
    # "\b" keeps the label from matching inside words such as "operating";
    # the character class after the delimiter skips whitespace and markdown
    # emphasis markers so that "Rating: **Buy**" yields "Buy".
    _RATING_LABEL_RE = re.compile(r"\brating.*?[:\-][\s*`]*(\w+)", re.IGNORECASE)

    def __init__(self, config: Optional[dict] = None):
        """Configure the log location.

        Args:
            config: Optional mapping; only the ``"memory_log_path"`` key is
                read. If it is missing or falsy the log is disabled.
                Otherwise the path is user-expanded and its parent
                directory is created if needed.
        """
        self._log_path: Optional[Path] = None
        path = (config or {}).get("memory_log_path")
        if path:
            self._log_path = Path(path).expanduser()
            self._log_path.parent.mkdir(parents=True, exist_ok=True)

    # --- Write path (Phase A) ---

    def store_decision(
        self,
        ticker: str,
        trade_date: str,
        final_trade_decision: str,
    ) -> None:
        """Append a pending entry for ``(trade_date, ticker)``.

        Does nothing when the log is disabled or when the file already
        contains a pending tag line for the same date and ticker.

        Args:
            ticker: Ticker symbol written into the tag.
            trade_date: Trade date written into the tag.
            final_trade_decision: Decision text stored verbatim under
                ``DECISION:``; its rating is extracted for the tag.
        """
        if not self._log_path:
            return
        # Idempotency guard: raw line scan instead of a full parse.
        if self._log_path.exists():
            prefix = f"[{trade_date} | {ticker} |"
            raw = self._log_path.read_text(encoding="utf-8")
            if any(
                line.startswith(prefix) and line.endswith(self._PENDING_SUFFIX)
                for line in raw.splitlines()
            ):
                return
        rating = self._parse_rating(final_trade_decision)
        tag = f"[{trade_date} | {ticker} | {rating} | pending]"
        entry = f"{tag}\n\nDECISION:\n{final_trade_decision}{self._SEPARATOR}"
        with open(self._log_path, "a", encoding="utf-8") as f:
            f.write(entry)

    # --- Read path (Phase A) ---

    def load_entries(self) -> List[dict]:
        """Parse every entry in the log.

        Returns:
            One dict per well-formed entry, in file order, with keys
            ``date``, ``ticker``, ``rating``, ``pending``, ``raw``,
            ``alpha``, ``holding``, ``decision`` and ``reflection``.
            Empty when the log is disabled or the file does not exist.
        """
        if not self._log_path or not self._log_path.exists():
            return []
        text = self._log_path.read_text(encoding="utf-8")
        entries = []
        for chunk in text.split(self._SEPARATOR):
            chunk = chunk.strip()
            if not chunk:
                continue
            parsed = self._parse_entry(chunk)
            if parsed:
                entries.append(parsed)
        return entries

    def get_pending_entries(self) -> List[dict]:
        """Return the entries whose tag is still marked ``pending``."""
        return [e for e in self.load_entries() if e.get("pending")]

    def get_past_context(self, ticker: str, n_same: int = 5, n_cross: int = 3) -> str:
        """Build a context string from resolved (non-pending) entries.

        Args:
            ticker: Ticker whose own history is rendered in full.
            n_same: Maximum number of entries for ``ticker``.
            n_cross: Maximum number of entries for other tickers.

        Returns:
            Sections joined by blank lines, most recent entries first, or
            ``""`` when there is nothing to report.
        """
        entries = [e for e in self.load_entries() if not e.get("pending")]
        if not entries:
            return ""
        same, cross = [], []
        for e in reversed(entries):
            if len(same) >= n_same and len(cross) >= n_cross:
                break
            if e["ticker"] == ticker and len(same) < n_same:
                same.append(e)
            elif e["ticker"] != ticker and len(cross) < n_cross:
                cross.append(e)
        if not same and not cross:
            return ""
        parts = []
        if same:
            parts.append(f"Past analyses of {ticker} (most recent first):")
            parts.extend(self._format_full(e) for e in same)
        if cross:
            parts.append("Recent cross-ticker lessons:")
            parts.extend(self._format_reflection_only(e) for e in cross)
        return "\n\n".join(parts)

    # --- Update path (Phase B) ---

    def update_with_outcome(
        self,
        ticker: str,
        trade_date: str,
        raw_return: float,
        alpha_return: float,
        holding_days: int,
        reflection: str,
    ) -> None:
        """Resolve the first pending entry matching ``(trade_date, ticker)``.

        The pending tag is replaced by one carrying the return figures and a
        ``REFLECTION:`` section is appended to the entry. The file is
        rewritten through a temporary sibling file that is then moved over
        the log. Nothing is written when the log is disabled, the file is
        missing, or no matching pending entry exists.

        Args:
            ticker: Ticker of the entry to resolve.
            trade_date: Trade date of the entry to resolve.
            raw_return: Raw return as a fraction (``0.05`` renders ``+5.0%``).
            alpha_return: Alpha return as a fraction, rendered the same way.
            holding_days: Holding period, rendered as ``<N>d``.
            reflection: Text stored under ``REFLECTION:``.
        """
        if not self._log_path or not self._log_path.exists():
            return
        update = {
            "raw_return": raw_return,
            "alpha_return": alpha_return,
            "holding_days": holding_days,
            "reflection": reflection,
        }
        self._apply_outcomes({(str(trade_date), str(ticker)): update})

    def batch_update_with_outcomes(self, updates: List[dict]) -> None:
        """Apply several outcome updates with one read and at most one write.

        Args:
            updates: Dicts with keys ``ticker``, ``trade_date``,
                ``raw_return``, ``alpha_return``, ``holding_days`` and
                ``reflection``. When two updates share a
                ``(trade_date, ticker)`` pair the later one wins. Each update
                resolves at most one (the first) matching pending entry.
        """
        if not self._log_path or not self._log_path.exists() or not updates:
            return
        # Keyed by (trade_date, ticker) so each block needs one dict lookup.
        update_map = {(str(u["trade_date"]), str(u["ticker"])): u for u in updates}
        self._apply_outcomes(update_map)

    def _apply_outcomes(self, update_map: dict) -> None:
        """Rewrite pending blocks that have an update in ``update_map``.

        ``update_map`` maps ``(trade_date, ticker)`` to an update dict and is
        consumed: every applied update is removed from it. The log is only
        rewritten when at least one block changed.
        """
        text = self._log_path.read_text(encoding="utf-8")
        blocks = text.split(self._SEPARATOR)
        changed = False
        for index, block in enumerate(blocks):
            if not update_map:
                break  # every requested update has been applied
            stripped = block.strip()
            if not stripped:
                continue
            lines = stripped.splitlines()
            pending = self._parse_pending_tag(lines[0].strip())
            if pending is None:
                continue
            trade_date, ticker, rating = pending
            upd = update_map.pop((trade_date, ticker), None)
            if upd is None:
                continue
            raw_pct = f"{upd['raw_return']:+.1%}"
            alpha_pct = f"{upd['alpha_return']:+.1%}"
            new_tag = (
                f"[{trade_date} | {ticker} | {rating}"
                f" | {raw_pct} | {alpha_pct} | {upd['holding_days']}d]"
            )
            rest = "\n".join(lines[1:])
            blocks[index] = (
                f"{new_tag}\n\n{rest.lstrip()}\n\nREFLECTION:\n{upd['reflection']}"
            )
            changed = True
        if changed:
            self._atomic_write(self._SEPARATOR.join(blocks))

    def _atomic_write(self, text: str) -> None:
        """Write ``text`` to a sibling temp file, then move it over the log.

        The temp name appends ``.tmp`` to the full file name (rather than
        replacing the suffix) so logs that differ only by extension do not
        share a temp file.
        """
        tmp_path = self._log_path.with_name(self._log_path.name + ".tmp")
        tmp_path.write_text(text, encoding="utf-8")
        tmp_path.replace(self._log_path)

    # --- Helpers ---

    def _parse_pending_tag(self, tag_line: str) -> Optional[tuple]:
        """Return ``(date, ticker, rating)`` for a pending tag line.

        Returns ``None`` unless the line starts with ``[``, ends with
        ``| pending]`` and has exactly four ``|``-separated fields, which is
        the shape ``store_decision`` writes.
        """
        if not (tag_line.startswith("[") and tag_line.endswith(self._PENDING_SUFFIX)):
            return None
        fields = [f.strip() for f in tag_line[1:-1].split("|")]
        if len(fields) != 4:
            return None
        return fields[0], fields[1], fields[2]

    def _parse_rating(self, text: str) -> str:
        """Extract a rating from free-form decision text.

        An explicit ``Rating: X`` / ``Rating - X`` label wins; otherwise the
        first word in the text that is a known rating is used.

        Returns:
            The capitalised rating, or ``"Hold"`` when none is found.
        """
        # First pass: explicit label on any line.
        for line in text.splitlines():
            m = self._RATING_LABEL_RE.search(line)
            if m and m.group(1).lower() in self.RATINGS:
                return m.group(1).capitalize()
        # Fallback: first rating word found anywhere in the text.
        for line in text.splitlines():
            for word in line.lower().split():
                clean = word.strip("*:.,")
                if clean in self.RATINGS:
                    return clean.capitalize()
        return "Hold"

    def _parse_entry(self, raw: str) -> Optional[dict]:
        """Parse one entry block into a dict.

        Returns ``None`` when the first line is not a ``[...]`` tag with at
        least four ``|``-separated fields.
        """
        lines = raw.strip().splitlines()
        if not lines:
            return None
        tag_line = lines[0].strip()
        if not (tag_line.startswith("[") and tag_line.endswith("]")):
            return None
        fields = [f.strip() for f in tag_line[1:-1].split("|")]
        if len(fields) < 4:
            return None
        entry = {
            "date": fields[0],
            "ticker": fields[1],
            "rating": fields[2],
            "pending": fields[3] == "pending",
            "raw": fields[3] if fields[3] != "pending" else None,
            "alpha": fields[4] if len(fields) > 4 else None,
            "holding": fields[5] if len(fields) > 5 else None,
        }
        body = "\n".join(lines[1:]).strip()
        decision_match = self._DECISION_RE.search(body)
        reflection_match = self._REFLECTION_RE.search(body)
        entry["decision"] = decision_match.group(1).strip() if decision_match else ""
        entry["reflection"] = reflection_match.group(1).strip() if reflection_match else ""
        return entry

    def _format_full(self, e: dict) -> str:
        """Render tag, decision and (if present) reflection of an entry."""
        raw = e["raw"] or "n/a"
        alpha = e["alpha"] or "n/a"
        holding = e["holding"] or "n/a"
        tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {raw} | {alpha} | {holding}]"
        parts = [tag, f"DECISION:\n{e['decision']}"]
        if e["reflection"]:
            parts.append(f"REFLECTION:\n{e['reflection']}")
        return "\n\n".join(parts)

    def _format_reflection_only(self, e: dict) -> str:
        """Render a short tag plus the reflection of an entry.

        Falls back to the first 300 characters of the decision (with a
        trailing ``...`` when truncated) if there is no reflection.
        """
        tag = f"[{e['date']} | {e['ticker']} | {e['rating']} | {e['raw'] or 'n/a'}]"
        if e["reflection"]:
            return f"{tag}\n{e['reflection']}"
        text = e["decision"][:300]
        suffix = "..." if len(e["decision"]) > 300 else ""
        return f"{tag}\n{text}{suffix}"