# File-viewer metadata captured with this paste: Python, 156 lines, 5.7 KiB.
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import pandas as pd
|
|
import yfinance as yf
|
|
|
|
from tradingagents.default_config import DEFAULT_CONFIG
|
|
from tradingagents.graph.trading_graph import TradingAgentsGraph
|
|
|
|
|
|
# Multiplier applied to an asset's forward return for each rating label.
# A positive value gains when the asset rises, a negative value gains when it
# falls, and 0.0 contributes nothing to the strategy return.
RATING_TO_EXPOSURE: dict[str, float] = {
    "BUY": 1.0,
    "OVERWEIGHT": 0.5,
    "HOLD": 0.0,
    "UNDERWEIGHT": -0.5,
    "SELL": -1.0,
    "NO_TRADE": 0.0,
}
|
|
|
|
|
|
def _fetch_forward_return(symbol: str, trade_date: str, holding_period: int) -> float | None:
|
|
start_dt = datetime.strptime(trade_date, "%Y-%m-%d")
|
|
end_dt = start_dt + timedelta(days=max(holding_period * 3, 10))
|
|
history = yf.Ticker(symbol).history(start=trade_date, end=end_dt.strftime("%Y-%m-%d"))
|
|
if history.empty or "Close" not in history:
|
|
return None
|
|
|
|
closes = history["Close"].dropna()
|
|
if len(closes) < 2:
|
|
return None
|
|
|
|
entry_price = float(closes.iloc[0])
|
|
exit_index = min(holding_period, len(closes) - 1)
|
|
exit_price = float(closes.iloc[exit_index])
|
|
return (exit_price / entry_price) - 1.0
|
|
|
|
|
|
def _compute_max_drawdown(returns: Iterable[float]) -> float:
|
|
cumulative = pd.Series(list(returns)).fillna(0.0).add(1.0).cumprod()
|
|
running_max = cumulative.cummax()
|
|
drawdown = (cumulative / running_max) - 1.0
|
|
return float(drawdown.min()) if not drawdown.empty else 0.0
|
|
|
|
|
|
def run_walk_forward_evaluation(
    symbols: list[str],
    trade_dates: list[str],
    *,
    holding_period: int = 5,
    benchmark_symbol: str = "SPY",
    graph_config: dict | None = None,
    selected_analysts: list[str] | None = None,
    enable_reflection: bool = False,
) -> dict:
    """Evaluate graph ratings against realized forward returns.

    For every trade date and every symbol the graph is asked for a rating, the
    rating is mapped to an exposure through ``RATING_TO_EXPOSURE`` (unknown
    ratings map to 0.0), and the exposure is multiplied by the instrument's
    forward return. Evaluations whose forward return is unavailable are
    skipped.

    Args:
        symbols: Instrument inputs handed to ``graph.propagate``.
        trade_dates: Trade dates, iterated in the order given.
        holding_period: Forward horizon passed to ``_fetch_forward_return``.
        benchmark_symbol: Ticker whose forward return is the benchmark.
        graph_config: Graph configuration; ``DEFAULT_CONFIG`` is used when it
            is ``None`` or empty. The configuration is deep-copied before use.
        selected_analysts: Analysts passed to the graph; defaults to
            ``["market", "social", "news", "fundamentals"]``.
        enable_reflection: When true, ``graph.reflect_and_remember`` is called
            with the strategy return of each evaluated trade.

    Returns:
        ``{"records": [...], "metrics": {...}}``. ``records`` holds one dict
        per evaluated trade. ``metrics`` is empty when nothing was evaluated;
        otherwise it contains ``hit_rate``, ``forward_return_by_rating_bucket``,
        ``turnover``, ``max_drawdown``, ``benchmark_excess_return``,
        ``abstain_frequency`` and ``region_split_metrics``.
    """
    config = deepcopy(graph_config or DEFAULT_CONFIG)
    graph = TradingAgentsGraph(
        config=config,
        selected_analysts=selected_analysts or ["market", "social", "news", "fundamentals"],
    )

    records: list[dict] = []
    # Turnover is the change in exposure of the SAME instrument between two of
    # its evaluations, so the previous exposure is remembered per resolved
    # ticker. A single shared value would compare unrelated instruments.
    # Instruments start flat (0.0).
    last_exposure: defaultdict[str, float] = defaultdict(float)

    for trade_date in trade_dates:
        # The benchmark depends only on the date, so fetch it once per date.
        benchmark_return = _fetch_forward_return(benchmark_symbol, trade_date, holding_period)
        for symbol in symbols:
            final_state, rating = graph.propagate(symbol, trade_date)
            resolved_symbol = final_state["company_of_interest"]
            asset_return = _fetch_forward_return(resolved_symbol, trade_date, holding_period)
            if asset_return is None:
                continue

            exposure = RATING_TO_EXPOSURE.get(rating, 0.0)
            strategy_return = exposure * asset_return
            turnover = abs(exposure - last_exposure[resolved_symbol])
            last_exposure[resolved_symbol] = exposure

            if enable_reflection:
                graph.reflect_and_remember(strategy_return)

            country = (final_state.get("instrument_profile") or {}).get("country", "UNKNOWN")
            records.append(
                {
                    "symbol": resolved_symbol,
                    "input_instrument": final_state.get("input_instrument", symbol),
                    "country": country,
                    "trade_date": trade_date,
                    "rating": rating,
                    "asset_return": asset_return,
                    "strategy_return": strategy_return,
                    "benchmark_return": benchmark_return,
                    "excess_return": None if benchmark_return is None else strategy_return - benchmark_return,
                    "turnover": turnover,
                }
            )

    if not records:
        return {"records": [], "metrics": {}}

    df = pd.DataFrame(records)
    bucket_metrics = df.groupby("rating")["asset_return"].mean().to_dict()
    region_metrics = (
        df.groupby("country")["strategy_return"]
        .agg(["mean", "count"])
        .rename(columns={"mean": "avg_strategy_return"})
        .to_dict(orient="index")
    )

    metrics = {
        "hit_rate": float((df["strategy_return"] > 0).mean()),
        "forward_return_by_rating_bucket": bucket_metrics,
        "turnover": float(df["turnover"].mean()),
        # Drawdown is computed over the records in evaluation order.
        "max_drawdown": _compute_max_drawdown(df["strategy_return"].tolist()),
        "benchmark_excess_return": float(df["excess_return"].dropna().mean()) if df["excess_return"].notna().any() else None,
        "abstain_frequency": float((df["rating"] == "NO_TRADE").mean()),
        "region_split_metrics": region_metrics,
    }
    return {"records": records, "metrics": metrics}
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments, run the evaluation, emit JSON.

    The result is written to the ``--output`` path when one is given (parent
    directories are created as needed); otherwise it is printed to stdout.
    """
    # Flags are registered in this order, which is also the order shown in --help.
    argument_specs = (
        ("--symbols", {"nargs": "+", "required": True, "help": "Instrument inputs, such as AAPL or 005930"}),
        ("--trade-dates", {"nargs": "+", "required": True, "help": "Trade dates in YYYY-MM-DD format"}),
        ("--holding-period", {"type": int, "default": 5, "help": "Forward holding period in trading days"}),
        ("--benchmark", {"default": "SPY", "help": "Benchmark ticker for excess-return comparison"}),
        ("--enable-reflection", {"action": "store_true", "help": "Call reflect_and_remember after each evaluated trade"}),
        ("--output", {"default": None, "help": "Optional JSON output path"}),
    )
    cli = argparse.ArgumentParser(description="Run a simple walk-forward evaluation for TradingAgents.")
    for flag, settings in argument_specs:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()

    evaluation = run_walk_forward_evaluation(
        symbols=options.symbols,
        trade_dates=options.trade_dates,
        holding_period=options.holding_period,
        benchmark_symbol=options.benchmark,
        enable_reflection=options.enable_reflection,
    )
    payload = json.dumps(evaluation, indent=2, ensure_ascii=False)

    if not options.output:
        print(payload)
        return

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(payload, encoding="utf-8")
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()
|