TradingAgents/tradingagents/strategies/backtest.py

359 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Backtest: compare decision quality with vs without strategy signals.
Loads historical TA decisions from two analysis runs:
- "baseline" (pre-strategy-signals, e.g. eval_results/ 2026-03-25)
- "enhanced" (with strategy signals, e.g. tradingagents/results/ 2026-04-14)
For each, retroactively computes strategy signals and measures:
1. Signal-decision alignment: did the TA decision agree with strategy signals?
2. Decision accuracy: did the TA decision predict the correct price direction?
3. Signal accuracy: did the strategy signals predict the correct price direction?
Outputs a JSON report + markdown summary.
Usage:
python -m tradingagents.strategies.backtest --baseline-date 2026-03-25 --enhanced-date 2026-04-14
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from collections import defaultdict
from pathlib import Path
import yfinance as yf
from tradingagents.strategies import compute_signals
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Rating vocabulary found in decision texts, grouped by the stance it implies.
BULLISH_RATINGS = set(("Buy", "Overweight"))
BEARISH_RATINGS = set(("Sell", "Underweight"))
NEUTRAL_RATINGS = set(("Hold",))

# Lookup from a single rating word to its direction label.
RATING_DIRECTION = {
    **dict.fromkeys(("Buy", "Overweight"), "BULLISH"),
    **dict.fromkeys(("Sell", "Underweight"), "BEARISH"),
    **dict.fromkeys(("Hold",), "NEUTRAL"),
}
def _extract_rating(text: str) -> str:
m = re.search(r"Rating:\s*\*{0,2}(Buy|Sell|Hold|Overweight|Underweight)\*{0,2}", text, re.IGNORECASE)
return m.group(1).capitalize() if m else "Hold"
def _get_price_change(ticker: str, from_date: str, to_date: str) -> float | None:
    """Return the percentage change in close price from ``from_date`` to ``to_date``.

    The change is measured between the first and last rows of the price
    history returned for the window.

    Args:
        ticker: Symbol passed to ``yf.Ticker``.
        from_date: Window start, passed to ``history(start=...)``.
        to_date: Last date that should be included in the window.

    Returns:
        The change in percent (``2.5`` means +2.5%), or ``None`` when fewer
        than two rows are available or the lookup fails.
    """
    from datetime import date, timedelta

    # yfinance treats ``end`` as exclusive, so passing ``to_date`` unchanged
    # would drop that day's close from the measurement. Extend the window by
    # one day; if the value is not an ISO date string, pass it through as-is.
    try:
        end = (date.fromisoformat(to_date) + timedelta(days=1)).isoformat()
    except (TypeError, ValueError):
        end = to_date

    try:
        hist = yf.Ticker(ticker).history(start=from_date, end=end)
        if hist.empty or len(hist) < 2:
            return None
        closes = hist["Close"]
        # Cast to a plain float so callers get a builtin, JSON-friendly number.
        return float((closes.iloc[-1] / closes.iloc[0] - 1) * 100)
    except Exception:
        # Best effort: any failure in the data lookup means "no price data"
        # for this ticker rather than aborting the caller.
        return None
def _load_eval_results(date: str) -> dict[str, str]:
"""Load decisions from eval_results/{TICKER}/TradingAgentsStrategy_logs/."""
results = {}
base = Path("eval_results")
if not base.exists():
return results
for ticker_dir in base.iterdir():
if not ticker_dir.is_dir():
continue
f = ticker_dir / "TradingAgentsStrategy_logs" / f"full_states_log_{date}.json"
if not f.exists():
continue
try:
data = json.loads(f.read_text())
state = data.get(date, data) # nested or flat
ftd = state.get("final_trade_decision", "")
if ftd:
results[ticker_dir.name] = ftd
except Exception:
pass
return results
def _load_results(date: str) -> dict[str, str]:
"""Load decisions from tradingagents/results/{TICKER}/TradingAgentsStrategy_logs/."""
results = {}
base = Path("tradingagents/results")
if not base.exists():
return results
for ticker_dir in base.iterdir():
if not ticker_dir.is_dir():
continue
f = ticker_dir / "TradingAgentsStrategy_logs" / f"full_states_log_{date}.json"
if not f.exists():
continue
try:
data = json.loads(f.read_text())
state = data.get(date, data)
ftd = state.get("final_trade_decision", "")
if ftd:
results[ticker_dir.name] = ftd
except Exception:
pass
return results
def _signal_consensus(signals: list[dict]) -> str:
"""Determine overall signal consensus: BULLISH, BEARISH, or NEUTRAL."""
supports = sum(1 for s in signals if s.get("direction") == "SUPPORTS")
contradicts = sum(1 for s in signals if s.get("direction") == "CONTRADICTS")
if supports > contradicts:
return "BULLISH"
elif contradicts > supports:
return "BEARISH"
return "NEUTRAL"
# ---------------------------------------------------------------------------
# Core backtest
# ---------------------------------------------------------------------------
def backtest_run(
    decisions: dict[str, str],
    analysis_date: str,
    eval_date: str,
    label: str,
    *,
    neutral_band_pct: float = 1.0,
) -> list[dict]:
    """Score a set of decisions against actual price movement.

    For each ticker (in sorted order):
    - Extract the rating from the decision text
    - Get the actual price change from analysis_date to eval_date
    - Compute strategy signals retroactively for analysis_date
    - Score decision accuracy, signal accuracy and decision/signal alignment

    Tickers with no price data are left out of the result entirely.

    Args:
        decisions: Mapping of ticker to its final decision text.
        analysis_date: Date the decisions were made.
        eval_date: Date the price move is evaluated against.
        label: Tag copied into every result row (e.g. "baseline").
        neutral_band_pct: Moves within +/- this many percent count as
            NEUTRAL rather than BULLISH/BEARISH. Defaults to 1.0.

    Returns:
        One dict per scored ticker. The ``decision_correct``,
        ``signal_correct`` and ``aligned`` entries are ``True``/``False``
        when both sides of the comparison are directional calls, and ``None``
        when the rating or the signal consensus is not directional.
    """
    directional = ("BULLISH", "BEARISH")
    results = []
    for ticker in sorted(decisions):
        rating = _extract_rating(decisions[ticker])
        rating_dir = RATING_DIRECTION.get(rating, "NEUTRAL")

        # Actual price movement
        pct = _get_price_change(ticker, analysis_date, eval_date)
        if pct is None:
            continue
        if pct > neutral_band_pct:
            actual_dir = "BULLISH"
        elif pct < -neutral_band_pct:
            actual_dir = "BEARISH"
        else:
            actual_dir = "NEUTRAL"

        # Retroactive strategy signals. A failure here is not fatal: the
        # ticker is still scored on its decision and reported with an "N/A"
        # consensus. Normalise to a list so a None or lazy return value
        # cannot break the counting below.
        try:
            signals = list(compute_signals(ticker, analysis_date) or ())
        except Exception:
            signals = []
        sig_consensus = _signal_consensus(signals) if signals else "N/A"
        n_supports = sum(1 for s in signals if s.get("direction") == "SUPPORTS")
        n_contradicts = sum(1 for s in signals if s.get("direction") == "CONTRADICTS")

        rating_is_call = rating_dir in directional
        signal_is_call = sig_consensus in directional

        # Decision accuracy: did rating predict direction?
        decision_correct = (rating_dir == actual_dir) if rating_is_call else None
        # Signal accuracy: did consensus predict direction?
        signal_correct = (sig_consensus == actual_dir) if signal_is_call else None
        # Alignment: did decision agree with signals?
        aligned = (rating_dir == sig_consensus) if rating_is_call and signal_is_call else None

        results.append({
            "ticker": ticker,
            "label": label,
            "analysis_date": analysis_date,
            "eval_date": eval_date,
            "rating": rating,
            "rating_direction": rating_dir,
            "pct_change": round(pct, 2),
            "actual_direction": actual_dir,
            "decision_correct": decision_correct,
            "signal_consensus": sig_consensus,
            "signal_correct": signal_correct,
            "aligned": aligned,
            "n_signals": len(signals),
            "n_supports": n_supports,
            "n_contradicts": n_contradicts,
        })
    return results
def _accuracy(results: list[dict], key: str) -> tuple[int, int, float]:
"""Count correct/total/pct for a boolean key (skipping None)."""
scored = [r for r in results if r.get(key) is not None]
if not scored:
return 0, 0, 0.0
correct = sum(1 for r in scored if r[key])
return correct, len(scored), correct / len(scored) if scored else 0.0
def _outcome_label(flag: bool | None) -> str:
    """Render a tri-state outcome as a table cell: yes / no / n/a (unscored)."""
    if flag is None:
        return "n/a"
    return "yes" if flag else "no"


def _summary_lines(baseline: list[dict], enhanced: list[dict]) -> list[str]:
    """Build the report title, summary table and metric footnotes."""
    b_dec_c, b_dec_t, b_dec_pct = _accuracy(baseline, "decision_correct")
    e_dec_c, e_dec_t, e_dec_pct = _accuracy(enhanced, "decision_correct")
    b_sig_c, b_sig_t, b_sig_pct = _accuracy(baseline, "signal_correct")
    e_sig_c, e_sig_t, e_sig_pct = _accuracy(enhanced, "signal_correct")
    e_align_c, e_align_t, e_align_pct = _accuracy(enhanced, "aligned")

    b_date = baseline[0]["analysis_date"] if baseline else "?"
    e_date = enhanced[0]["analysis_date"] if enhanced else "?"
    # Prefer the enhanced run's evaluation date, then the baseline's.
    if enhanced:
        eval_date = enhanced[0]["eval_date"]
    elif baseline:
        eval_date = baseline[0]["eval_date"]
    else:
        eval_date = "?"

    return [
        "# Strategy Signals Backtest Report\n",
        "Comparing decision quality **with** vs **without** strategy signals.\n",
        "## Summary\n",
        f"| Metric | Baseline ({b_date}) | Enhanced ({e_date}) | Delta |",
        "|--------|---:|---:|---:|",
        f"| Tickers analyzed | {len(baseline)} | {len(enhanced)} | |",
        f"| Decision accuracy | {b_dec_c}/{b_dec_t} ({b_dec_pct:.0%}) | {e_dec_c}/{e_dec_t} ({e_dec_pct:.0%}) | {e_dec_pct - b_dec_pct:+.0%} |",
        f"| Signal accuracy (retroactive) | {b_sig_c}/{b_sig_t} ({b_sig_pct:.0%}) | {e_sig_c}/{e_sig_t} ({e_sig_pct:.0%}) | {e_sig_pct - b_sig_pct:+.0%} |",
        f"| Decision-signal alignment | — | {e_align_c}/{e_align_t} ({e_align_pct:.0%}) | |",
        f"| Evaluation date | {eval_date} | {eval_date} | |",
        "",
        "*Decision accuracy: did the rating (Buy/Sell) predict the correct price direction?*",
        "*Signal accuracy: did the strategy signal consensus predict the correct direction?*",
        "*Alignment: did the enhanced decision agree with strategy signals?*\n",
    ]


def _head_to_head_lines(baseline: list[dict], enhanced: list[dict]) -> list[str]:
    """Build the per-ticker comparison for tickers present in both runs.

    Returns an empty list when the two runs share no ticker.
    """
    b_tickers = {r["ticker"]: r for r in baseline}
    e_tickers = {r["ticker"]: r for r in enhanced}
    overlap = sorted(set(b_tickers) & set(e_tickers))
    if not overlap:
        return []

    lines = [
        "## Head-to-Head (overlapping tickers)\n",
        "| Ticker | Baseline Rating | Enhanced Rating | Actual Move | Baseline Correct | Enhanced Correct | Signals Agreed |",
        "|--------|----------------|----------------|------------|:---:|:---:|:---:|",
    ]
    for t in overlap:
        b = b_tickers[t]
        e = e_tickers[t]
        lines.append(
            f"| {t} | {b['rating']} | {e['rating']} | {e['pct_change']:+.1f}% "
            f"| {_outcome_label(b['decision_correct'])} "
            f"| {_outcome_label(e['decision_correct'])} "
            f"| {_outcome_label(e['aligned'])} |"
        )

    # Accuracy restricted to the overlapping tickers.
    ob_c, ob_t, ob_pct = _accuracy([b_tickers[t] for t in overlap], "decision_correct")
    oe_c, oe_t, oe_pct = _accuracy([e_tickers[t] for t in overlap], "decision_correct")
    lines.append(f"\nOverlap accuracy: baseline {ob_c}/{ob_t} ({ob_pct:.0%}) vs enhanced {oe_c}/{oe_t} ({oe_pct:.0%})\n")
    return lines


def _per_strategy_lines(rows: list[dict]) -> list[str]:
    """Build the per-strategy accuracy table across all result rows.

    Signals are recomputed for each row's ticker and analysis date; rows whose
    signal computation fails are skipped. Signals with a NEUTRAL (or missing)
    direction are not scored. Returns an empty list when nothing was scored.
    """
    strat_stats: dict[str, dict] = defaultdict(lambda: {"correct": 0, "total": 0})
    for r in rows:
        try:
            sigs = compute_signals(r["ticker"], r["analysis_date"])
        except Exception:
            # Best effort: a ticker whose signals cannot be recomputed simply
            # does not contribute to this section.
            continue
        for s in sigs:
            d = s.get("direction", "NEUTRAL")
            if d == "NEUTRAL":
                continue
            predicted = "BULLISH" if d == "SUPPORTS" else "BEARISH"
            stats = strat_stats[s.get("name", "?")]
            stats["total"] += 1
            if predicted == r["actual_direction"]:
                stats["correct"] += 1

    if not strat_stats:
        return []

    lines = [
        "## Per-Strategy Accuracy (across all tickers)\n",
        "| Strategy | Correct | Total | Accuracy |",
        "|----------|--------:|------:|---------:|",
    ]
    # Most accurate strategy first; ties keep first-seen order.
    ranked = sorted(
        strat_stats,
        key=lambda n: strat_stats[n]["correct"] / max(strat_stats[n]["total"], 1),
        reverse=True,
    )
    for name in ranked:
        st = strat_stats[name]
        acc = st["correct"] / st["total"] if st["total"] else 0
        display = name.replace("_", " ").title()
        lines.append(f"| {display} | {st['correct']} | {st['total']} | {acc:.0%} |")
    lines.append("")
    return lines


def generate_report(baseline: list[dict], enhanced: list[dict], output_dir: Path) -> Path:
    """Generate markdown + JSON backtest comparison report.

    Writes ``backtest_results.json`` (the raw per-ticker rows of both runs)
    and ``backtest_report.md`` (summary table, head-to-head comparison of
    overlapping tickers, per-strategy accuracy) into ``output_dir``, creating
    the directory if needed.

    Args:
        baseline: Per-ticker result rows of the run without strategy signals.
        enhanced: Per-ticker result rows of the run with strategy signals.
        output_dir: Directory that receives both files.

    Returns:
        Path of the markdown report.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # JSON
    json_path = output_dir / "backtest_results.json"
    json_path.write_text(json.dumps({"baseline": baseline, "enhanced": enhanced}, indent=2))

    # Markdown
    lines = _summary_lines(baseline, enhanced)
    lines += _head_to_head_lines(baseline, enhanced)
    lines += _per_strategy_lines(baseline + enhanced)
    lines.append("\n---\n*Generated by `python -m tradingagents.strategies.backtest`*\n")

    # Written with the platform default encoding, matching how the CLI reads
    # the file back to echo it.
    md_path = output_dir / "backtest_report.md"
    md_path.write_text("\n".join(lines))
    return md_path
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """Command-line entry point.

    Parses the date/output options, loads the baseline and enhanced decision
    sets (each falling back to the other results directory when its primary
    one yields nothing), scores both runs, writes the report and echoes the
    markdown to stdout. Progress messages go to stderr. Exits with status 1
    when neither decision set could be loaded.
    """

    def _note(message: str) -> None:
        # Progress goes to stderr so stdout carries only the report.
        print(message, file=sys.stderr)

    cli = argparse.ArgumentParser(description="Backtest strategy signals vs historical decisions")
    for flag, default, help_text in (
        ("--baseline-date", "2026-03-25", "Date of baseline (no signals) analysis"),
        ("--enhanced-date", "2026-04-14", "Date of enhanced (with signals) analysis"),
        ("--eval-date", "2026-04-16", "Date to evaluate price movement against"),
        ("--output", "./data/backtest", "Output directory for report"),
    ):
        cli.add_argument(flag, default=default, help=help_text)
    opts = cli.parse_args()

    _note(f"Loading baseline decisions ({opts.baseline_date})...")
    baseline_decisions = _load_eval_results(opts.baseline_date) or _load_results(opts.baseline_date)
    _note(f" {len(baseline_decisions)} tickers")

    _note(f"Loading enhanced decisions ({opts.enhanced_date})...")
    enhanced_decisions = _load_results(opts.enhanced_date) or _load_eval_results(opts.enhanced_date)
    _note(f" {len(enhanced_decisions)} tickers")

    if not (baseline_decisions or enhanced_decisions):
        _note("No decisions found. Ensure eval_results/ or tradingagents/results/ exist.")
        sys.exit(1)

    _note(f"Computing strategy signals + price changes (eval: {opts.eval_date})...")
    baseline = backtest_run(baseline_decisions, opts.baseline_date, opts.eval_date, "baseline")
    enhanced = backtest_run(enhanced_decisions, opts.enhanced_date, opts.eval_date, "enhanced")
    _note(f"Baseline: {len(baseline)} scored, Enhanced: {len(enhanced)} scored")

    report_path = generate_report(baseline, enhanced, Path(opts.output))
    _note(f"\nReport: {report_path}")
    print(report_path.read_text())


if __name__ == "__main__":
    main()