"""Simple walk-forward evaluation for TradingAgents.

Source path: TradingAgents/tradingagents/eval/walk_forward.py
"""

from __future__ import annotations
import argparse
import json
from collections import defaultdict
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
from typing import Iterable
import pandas as pd
import yfinance as yf
from tradingagents.default_config import DEFAULT_CONFIG
from tradingagents.graph.trading_graph import TradingAgentsGraph
# Signed exposure assigned to each rating label.  The evaluation multiplies the
# asset's forward return by this value, so negative numbers profit from a
# falling price and 0.0 means the rating contributes no return.
RATING_TO_EXPOSURE = {
    "BUY": 1.0,
    "OVERWEIGHT": 0.5,
    "HOLD": 0.0,
    "UNDERWEIGHT": -0.5,
    "SELL": -1.0,
    "NO_TRADE": 0.0,
}
def _fetch_forward_return(symbol: str, trade_date: str, holding_period: int) -> float | None:
    """Return the simple close-to-close return realised after ``trade_date``.

    Price history for ``symbol`` is requested through yfinance starting at
    ``trade_date``.  The entry price is the first non-missing close in the
    returned window; the exit price is the close ``holding_period`` rows
    later, capped at the last row that was returned.

    Args:
        symbol: Ticker passed to ``yf.Ticker``.
        trade_date: Entry date formatted as ``YYYY-MM-DD``.
        holding_period: Number of price rows between entry and exit.

    Returns:
        ``exit_price / entry_price - 1`` as a float, or ``None`` when the
        history is empty, has no ``Close`` column, contains fewer than two
        closes, or starts with a non-positive close.

    Raises:
        ValueError: If ``trade_date`` does not match ``YYYY-MM-DD``.
    """
    start_dt = datetime.strptime(trade_date, "%Y-%m-%d")
    # Ask for a calendar window about three times the holding period (and at
    # least ten days) so non-trading days still leave enough price rows.
    end_dt = start_dt + timedelta(days=max(holding_period * 3, 10))
    history = yf.Ticker(symbol).history(start=trade_date, end=end_dt.strftime("%Y-%m-%d"))
    if history.empty or "Close" not in history:
        return None

    closes = history["Close"].dropna()
    if len(closes) < 2:
        return None

    entry_price = float(closes.iloc[0])
    if entry_price <= 0.0:
        # A zero or negative close cannot anchor a percentage return; treat it
        # as missing data instead of raising ZeroDivisionError mid-evaluation.
        return None

    # If fewer rows than requested came back, fall back to the last one.
    exit_index = min(holding_period, len(closes) - 1)
    exit_price = float(closes.iloc[exit_index])
    return (exit_price / entry_price) - 1.0
def _compute_max_drawdown(returns: Iterable[float]) -> float:
cumulative = pd.Series(list(returns)).fillna(0.0).add(1.0).cumprod()
running_max = cumulative.cummax()
drawdown = (cumulative / running_max) - 1.0
return float(drawdown.min()) if not drawdown.empty else 0.0
def run_walk_forward_evaluation(
    symbols: list[str],
    trade_dates: list[str],
    *,
    holding_period: int = 5,
    benchmark_symbol: str = "SPY",
    graph_config: dict | None = None,
    selected_analysts: list[str] | None = None,
    enable_reflection: bool = False,
) -> dict:
    """Run the trading graph over every (trade date, symbol) pair and score it.

    For each pair the graph is asked for a rating, the rating is mapped to an
    exposure through ``RATING_TO_EXPOSURE`` (unknown ratings map to 0.0), and
    the exposure is multiplied by the asset's forward return.  Pairs whose
    forward return cannot be fetched are skipped.

    Args:
        symbols: Instrument inputs handed to ``graph.propagate``.
        trade_dates: Trade dates formatted as ``YYYY-MM-DD``, in the order
            they should be evaluated.
        holding_period: Holding period passed to ``_fetch_forward_return``.
        benchmark_symbol: Ticker whose forward return is used as benchmark.
        graph_config: Graph configuration; ``DEFAULT_CONFIG`` is used when
            this is ``None`` or empty.  The configuration is deep-copied.
        selected_analysts: Analysts to enable; defaults to market, social,
            news and fundamentals.
        enable_reflection: When true, ``graph.reflect_and_remember`` is called
            with the strategy return of every evaluated pair.

    Returns:
        A dict with ``"records"`` (one dict per evaluated pair) and
        ``"metrics"`` (aggregates over those records).  Both are empty when
        nothing could be evaluated.
    """
    config = deepcopy(graph_config or DEFAULT_CONFIG)
    graph = TradingAgentsGraph(
        config=config,
        selected_analysts=selected_analysts or ["market", "social", "news", "fundamentals"],
    )

    records: list[dict] = []
    # Last exposure taken for each input symbol.  Turnover is the change in a
    # symbol's own exposure between consecutive evaluated dates, so it must not
    # be compared against whichever other symbol happened to be processed last.
    previous_exposure: dict[str, float] = {}

    for trade_date in trade_dates:
        # The benchmark return depends only on the date; fetch it once per date.
        benchmark_return = _fetch_forward_return(benchmark_symbol, trade_date, holding_period)

        for symbol in symbols:
            final_state, rating = graph.propagate(symbol, trade_date)
            # Price the instrument the graph actually resolved, which may
            # differ from the raw input symbol.
            resolved_symbol = final_state["company_of_interest"]
            asset_return = _fetch_forward_return(resolved_symbol, trade_date, holding_period)
            if asset_return is None:
                continue

            exposure = RATING_TO_EXPOSURE.get(rating, 0.0)
            strategy_return = exposure * asset_return
            turnover = abs(exposure - previous_exposure.get(symbol, 0.0))
            previous_exposure[symbol] = exposure

            if enable_reflection:
                graph.reflect_and_remember(strategy_return)

            country = (final_state.get("instrument_profile") or {}).get("country", "UNKNOWN")
            records.append(
                {
                    "symbol": resolved_symbol,
                    "input_instrument": final_state.get("input_instrument", symbol),
                    "country": country,
                    "trade_date": trade_date,
                    "rating": rating,
                    "asset_return": asset_return,
                    "strategy_return": strategy_return,
                    "benchmark_return": benchmark_return,
                    "excess_return": None if benchmark_return is None else strategy_return - benchmark_return,
                    "turnover": turnover,
                }
            )

    if not records:
        return {"records": [], "metrics": {}}

    df = pd.DataFrame(records)
    bucket_metrics = df.groupby("rating")["asset_return"].mean().to_dict()
    region_metrics = (
        df.groupby("country")["strategy_return"]
        .agg(["mean", "count"])
        .rename(columns={"mean": "avg_strategy_return"})
        .to_dict(orient="index")
    )
    excess_returns = df["excess_return"].dropna()
    metrics = {
        # Rows with zero exposure have a strategy return of 0.0 and therefore
        # do not count as hits.
        "hit_rate": float((df["strategy_return"] > 0).mean()),
        "forward_return_by_rating_bucket": bucket_metrics,
        "turnover": float(df["turnover"].mean()),
        "max_drawdown": _compute_max_drawdown(df["strategy_return"].tolist()),
        "benchmark_excess_return": float(excess_returns.mean()) if not excess_returns.empty else None,
        "abstain_frequency": float((df["rating"] == "NO_TRADE").mean()),
        "region_split_metrics": region_metrics,
    }
    return {"records": records, "metrics": metrics}
def main():
    """Parse command-line options, run the evaluation and emit it as JSON.

    The JSON document is written to the ``--output`` path when one is given
    (parent directories are created as needed); otherwise it is printed to
    standard output.
    """
    cli = argparse.ArgumentParser(description="Run a simple walk-forward evaluation for TradingAgents.")
    cli.add_argument("--symbols", nargs="+", required=True, help="Instrument inputs, such as AAPL or 005930")
    cli.add_argument("--trade-dates", nargs="+", required=True, help="Trade dates in YYYY-MM-DD format")
    cli.add_argument("--holding-period", type=int, default=5, help="Forward holding period in trading days")
    cli.add_argument("--benchmark", default="SPY", help="Benchmark ticker for excess-return comparison")
    cli.add_argument(
        "--enable-reflection",
        action="store_true",
        help="Call reflect_and_remember after each evaluated trade",
    )
    cli.add_argument("--output", default=None, help="Optional JSON output path")
    options = cli.parse_args()

    evaluation = run_walk_forward_evaluation(
        symbols=options.symbols,
        trade_dates=options.trade_dates,
        holding_period=options.holding_period,
        benchmark_symbol=options.benchmark,
        enable_reflection=options.enable_reflection,
    )
    payload = json.dumps(evaluation, indent=2, ensure_ascii=False)

    # No output path: the report goes to stdout and we are done.
    if not options.output:
        print(payload)
        return

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(payload, encoding="utf-8")


if __name__ == "__main__":
    main()