TradingAgents/scripts/build_ml_dataset.py

#!/usr/bin/env python3
"""Build an ML training dataset from historical OHLCV data.

Fetches daily price data for a universe of liquid stocks, computes features
locally via stockstats, and applies triple-barrier labels. Writes the result
to data/ml/training_dataset.parquet by default (see --output).

Usage:
    python scripts/build_ml_dataset.py
    python scripts/build_ml_dataset.py --stocks 100 --start 2023-01-01
    python scripts/build_ml_dataset.py --ticker-file data/tickers_top50.txt
"""
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

import pandas as pd

# Add the project root to sys.path so `tradingagents` imports resolve
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from tradingagents.ml.feature_engineering import (
    FEATURE_COLUMNS,
    MIN_HISTORY_ROWS,
    apply_triple_barrier_labels,
    compute_features_bulk,
)
from tradingagents.utils.logger import get_logger

logger = get_logger(__name__)

# Default universe (~200 names): liquid US large caps across sectors, mostly
# S&P 500, plus popular high-volatility retail tickers.
# Can be overridden via --ticker-file.
DEFAULT_TICKERS = [
    # Mega-cap tech
    "AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA", "AVGO", "ORCL", "CRM",
    "AMD", "INTC", "CSCO", "ADBE", "NFLX", "QCOM", "TXN", "AMAT", "MU", "LRCX",
    "KLAC", "MRVL", "SNPS", "CDNS", "PANW", "CRWD", "FTNT", "NOW", "UBER", "ABNB",
    # Financials
    "JPM", "BAC", "WFC", "GS", "MS", "C", "SCHW", "BLK", "AXP", "USB",
    "PNC", "TFC", "COF", "BK", "STT", "FITB", "HBAN", "RF", "CFG", "KEY",
    # Healthcare
    "UNH", "JNJ", "LLY", "PFE", "ABBV", "MRK", "TMO", "ABT", "DHR", "BMY",
    "AMGN", "GILD", "ISRG", "VRTX", "REGN", "MDT", "SYK", "BSX", "EW", "ZTS",
    # Consumer
    "WMT", "PG", "KO", "PEP", "COST", "MCD", "NKE", "SBUX", "TGT", "LOW",
    "HD", "TJX", "ROST", "DG", "DLTR", "EL", "CL", "KMB", "GIS", "K",
    # Energy
    "XOM", "CVX", "COP", "EOG", "SLB", "MPC", "PSX", "VLO", "OXY", "DVN",
    "HAL", "FANG", "HES", "BKR", "KMI", "WMB", "OKE", "ET", "TRGP", "LNG",
    # Industrials
    "CAT", "DE", "UNP", "UPS", "HON", "RTX", "BA", "LMT", "GD", "NOC",
    "GE", "MMM", "EMR", "ITW", "PH", "ROK", "ETN", "SWK", "CMI", "PCAR",
    # Materials & Utilities
    "LIN", "APD", "ECL", "SHW", "DD", "NEM", "FCX", "VMC", "MLM", "NUE",
    "NEE", "DUK", "SO", "D", "AEP", "EXC", "SRE", "XEL", "WEC", "ES",
    # REITs & Telecom
    "AMT", "PLD", "CCI", "EQIX", "SPG", "O", "PSA", "DLR", "WELL", "AVB",
    "T", "VZ", "TMUS", "CHTR", "CMCSA",
    # High-volatility / popular retail
    "COIN", "MARA", "RIOT", "PLTR", "SOFI", "HOOD", "RBLX", "SNAP", "PINS", "SQ",
    "SHOP", "SE", "ROKU", "DKNG", "PENN", "WYNN", "MGM", "LVS", "DASH", "TTD",
    # Biotech
    "MRNA", "BNTX", "BIIB", "SGEN", "ALNY", "BMRN", "EXAS", "DXCM", "HZNP", "INCY",
]
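
# NOTE: a static, hard-coded universe carries survivorship caveats. Names
# acquired or delisted during the sample window (e.g. SGEN and HZNP, both
# acquired in 2023) may return little or no data and are simply skipped by
# process_ticker().
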
OUTPUT_DIR = Path("data/ml")


def fetch_ohlcv(ticker: str, start: str, end: str) -> pd.DataFrame:
    """Fetch OHLCV data for a single ticker via yfinance."""
    from tradingagents.dataflows.y_finance import download_history

    df = download_history(
        ticker,
        start=start,
        end=end,
        multi_level_index=False,
        progress=False,
        auto_adjust=True,
    )
    if df.empty:
        return df
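    # Assumption: download_history mirrors yfinance.download, returning a
    # Date-indexed frame of Open/High/Low/Close/Volume with splits and
    # dividends folded into prices via auto_adjust=True. reset_index() below
    # surfaces Date as a regular column, which process_ticker() relies on.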
    df = df.reset_index()
    return df


def get_market_cap(ticker: str) -> float | None:
    """Get current market cap for a ticker (snapshot — used as static feature)."""
    try:
        import yfinance as yf

        info = yf.Ticker(ticker).info
        return info.get("marketCap")
    except Exception:
        return None
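
# NOTE: this is today's market cap applied to every historical row for the
# ticker, so the resulting feature leaks a small amount of future information
# (a company's current size reflects its returns over the sample). Treat it as
# a coarse static size bucket, not a point-in-time feature.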


def process_ticker(
    ticker: str,
    start: str,
    end: str,
    profit_target: float,
    stop_loss: float,
    max_holding_days: int,
    market_cap: float | None = None,
) -> pd.DataFrame | None:
    """Process a single ticker: fetch data, compute features, apply labels."""
    try:
        ohlcv = fetch_ohlcv(ticker, start, end)
        if ohlcv.empty or len(ohlcv) < MIN_HISTORY_ROWS + max_holding_days:
            logger.debug(f"{ticker}: insufficient data ({len(ohlcv)} rows), skipping")
            return None

        # Compute features
        features = compute_features_bulk(ohlcv, market_cap=market_cap)
        if features.empty:
            logger.debug(f"{ticker}: feature computation failed, skipping")
            return None

        # Compute triple-barrier labels on a date-indexed close series
        close = ohlcv.set_index("Date")["Close"] if "Date" in ohlcv.columns else ohlcv["Close"]
        if not isinstance(close.index, pd.DatetimeIndex):
            close.index = pd.to_datetime(close.index)
        labels = apply_triple_barrier_labels(
            close,
            profit_target=profit_target,
            stop_loss=stop_loss,
            max_holding_days=max_holding_days,
        )
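        # Label semantics, inferred from the WIN/LOSS/TIMEOUT logging below:
        # for each entry day t the position opens at close[t]; the label is +1
        # if close first reaches entry * (1 + profit_target) within
        # max_holding_days, -1 if it first reaches entry * (1 - stop_loss),
        # and 0 if neither barrier is touched before the horizon expires.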
        # Align features and labels by date
        combined = features.join(labels, how="inner")
        # Drop rows with NaN features or labels
        combined = combined.dropna(subset=["label"] + FEATURE_COLUMNS)
        if combined.empty:
            logger.debug(f"{ticker}: no valid rows after alignment, skipping")
            return None

        # Add metadata columns
        combined["ticker"] = ticker
        combined["date"] = combined.index
        logger.info(
            f"{ticker}: {len(combined)} samples "
            f"(WIN={int((combined['label'] == 1).sum())}, "
            f"LOSS={int((combined['label'] == -1).sum())}, "
            f"TIMEOUT={int((combined['label'] == 0).sum())})"
        )
        return combined
    except Exception as e:
        logger.warning(f"{ticker}: error processing — {e}")
        return None
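
# A minimal smoke test for one symbol (hypothetical values; run in a REPL):
#   df = process_ticker("AAPL", "2022-01-01", "2023-12-31",
#                       profit_target=0.05, stop_loss=0.03, max_holding_days=7)
#   assert df is None or {"label", "ticker", "date"}.issubset(df.columns)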


def build_dataset(
    tickers: list[str],
    start: str = "2022-01-01",
    end: str = "2025-12-31",
    profit_target: float = 0.05,
    stop_loss: float = 0.03,
    max_holding_days: int = 7,
) -> pd.DataFrame:
    """Build the full training dataset across all tickers."""
    all_data = []
    total = len(tickers)
    logger.info(f"Building ML dataset: {total} tickers, {start} to {end}")
    logger.info(
        f"Triple-barrier: +{profit_target*100:.0f}% profit, "
        f"-{stop_loss*100:.0f}% stop, {max_holding_days}d timeout"
    )

    # Fetch market caps up front (one snapshot request per ticker, not batched)
    logger.info("Fetching market caps...")
    market_caps = {}
    for ticker in tickers:
        market_caps[ticker] = get_market_cap(ticker)
        time.sleep(0.05)  # rate-limit courtesy

    for i, ticker in enumerate(tickers):
        logger.info(f"[{i+1}/{total}] Processing {ticker}...")
        result = process_ticker(
            ticker=ticker,
            start=start,
            end=end,
            profit_target=profit_target,
            stop_loss=stop_loss,
            max_holding_days=max_holding_days,
            market_cap=market_caps.get(ticker),
        )
        if result is not None:
            all_data.append(result)
        # Brief pause every 50 tickers to be polite to yfinance
        if (i + 1) % 50 == 0:
            logger.info(f"Progress: {i+1}/{total} tickers processed, pausing 2s...")
            time.sleep(2)

    if not all_data:
        logger.error("No data collected — check tickers and date range")
        return pd.DataFrame()

    dataset = pd.concat(all_data, ignore_index=True)
    logger.info(f"\n{'='*60}")
    logger.info(f"Dataset built: {len(dataset)} total samples from {len(all_data)} tickers")
    logger.info("Label distribution:")
    logger.info(
        f"  WIN (+1):  {int((dataset['label'] == 1).sum()):>7} ({(dataset['label'] == 1).mean()*100:.1f}%)"
    )
    logger.info(
        f"  LOSS (-1): {int((dataset['label'] == -1).sum()):>7} ({(dataset['label'] == -1).mean()*100:.1f}%)"
    )
    logger.info(
        f"  TIMEOUT:   {int((dataset['label'] == 0).sum()):>7} ({(dataset['label'] == 0).mean()*100:.1f}%)"
    )
    logger.info(f"Features: {len(FEATURE_COLUMNS)}")
    logger.info(f"{'='*60}")
    return dataset
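
# Sketch of downstream use (assumptions: the parquet written by main() and a
# chronological split; splitting by time rather than at random avoids leaking
# overlapping label horizons between train and test):
#   df = pd.read_parquet("data/ml/training_dataset.parquet").sort_values("date")
#   cutoff = df["date"].quantile(0.8)
#   train, test = df[df["date"] <= cutoff], df[df["date"] > cutoff]
#   X_train, y_train = train[FEATURE_COLUMNS], train["label"]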


def main():
    parser = argparse.ArgumentParser(description="Build ML training dataset")
    parser.add_argument(
        "--stocks", type=int, default=None, help="Limit to N stocks from default universe"
    )
    parser.add_argument(
        "--ticker-file", type=str, default=None, help="File with tickers (one per line)"
    )
    parser.add_argument("--start", type=str, default="2022-01-01", help="Start date (YYYY-MM-DD)")
    parser.add_argument("--end", type=str, default="2025-12-31", help="End date (YYYY-MM-DD)")
    parser.add_argument(
        "--profit-target", type=float, default=0.05, help="Profit target fraction (default: 0.05)"
    )
    parser.add_argument(
        "--stop-loss", type=float, default=0.03, help="Stop loss fraction (default: 0.03)"
    )
    parser.add_argument("--holding-days", type=int, default=7, help="Max holding days (default: 7)")
    parser.add_argument("--output", type=str, default=None, help="Output parquet path")
    args = parser.parse_args()

    # Determine ticker list (skip blank lines and comment lines)
    if args.ticker_file:
        with open(args.ticker_file) as f:
            tickers = [
                line.strip().upper()
                for line in f
                if line.strip() and not line.lstrip().startswith("#")
            ]
        logger.info(f"Loaded {len(tickers)} tickers from {args.ticker_file}")
    else:
        tickers = DEFAULT_TICKERS
    if args.stocks:
        tickers = tickers[: args.stocks]

    # Build dataset
    dataset = build_dataset(
        tickers=tickers,
        start=args.start,
        end=args.end,
        profit_target=args.profit_target,
        stop_loss=args.stop_loss,
        max_holding_days=args.holding_days,
    )
    if dataset.empty:
        logger.error("Empty dataset — aborting")
        sys.exit(1)

    # Save (create the target directory, including for a custom --output path)
    output_path = Path(args.output) if args.output else OUTPUT_DIR / "training_dataset.parquet"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(output_path, index=False)
    logger.info(f"Saved dataset to {output_path} ({os.path.getsize(output_path) / 1e6:.1f} MB)")


if __name__ == "__main__":
    main()