#!/usr/bin/env python3
"""Train ML model on the generated dataset.

Supports TabPFN (recommended, requires GPU or API) and LightGBM (fallback).
Uses time-based train/validation split to prevent data leakage.

Usage:
    python scripts/train_ml_model.py
    python scripts/train_ml_model.py --model lightgbm
    python scripts/train_ml_model.py --model tabpfn --dataset data/ml/training_dataset.parquet
    python scripts/train_ml_model.py --max-train-samples 5000
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
)

# Add project root to path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from tradingagents.ml.feature_engineering import FEATURE_COLUMNS
from tradingagents.ml.predictor import LGBMWrapper, MLPredictor
from tradingagents.utils.logger import get_logger

logger = get_logger(__name__)

DATA_DIR = Path("data/ml")
LABEL_NAMES = {-1: "LOSS", 0: "TIMEOUT", 1: "WIN"}


def load_dataset(path: str) -> pd.DataFrame:
    """Load and validate the training dataset."""
    df = pd.read_parquet(path)
    logger.info(f"Loaded {len(df)} samples from {path}")

    # Validate columns
    missing = [c for c in FEATURE_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing feature columns: {missing}")
    if "label" not in df.columns:
        raise ValueError("Missing 'label' column")
    if "date" not in df.columns:
        raise ValueError("Missing 'date' column")

    # Show label distribution
    for label, name in LABEL_NAMES.items():
        count = (df["label"] == label).sum()
        pct = count / len(df) * 100
        logger.info(f"  {name:>7} ({label:+d}): {count:>7} ({pct:.1f}%)")

    return df


def time_split(
    df: pd.DataFrame,
    val_start: str = "2024-07-01",
    max_train_samples: int | None = None,
) -> tuple:
    """Split dataset by time — train on older data, validate on newer."""
    df["date"] = pd.to_datetime(df["date"])
    val_start_dt = pd.Timestamp(val_start)

    train = df[df["date"] < val_start_dt].copy()
    val = df[df["date"] >= val_start_dt].copy()

    if max_train_samples is not None and len(train) > max_train_samples:
        train = train.sort_values("date").tail(max_train_samples)
        logger.info(
            f"Limiting training samples to most recent {max_train_samples} before {val_start}"
        )

    # Guard the range logging: .min()/.max() on an empty split return NaT,
    # and NaT.date() raises before main() can report the problem cleanly.
    logger.info(f"Time-based split at {val_start}:")
    if len(train) > 0:
        logger.info(
            f"  Train: {len(train)} samples ({train['date'].min().date()} to {train['date'].max().date()})"
        )
    else:
        logger.warning("  Train: 0 samples")
    if len(val) > 0:
        logger.info(
            f"  Val:   {len(val)} samples ({val['date'].min().date()} to {val['date'].max().date()})"
        )
    else:
        logger.warning("  Val: 0 samples")

    X_train = train[FEATURE_COLUMNS].values
    y_train = train["label"].values.astype(int)
    X_val = val[FEATURE_COLUMNS].values
    y_val = val["label"].values.astype(int)

    return X_train, y_train, X_val, y_val


def train_tabpfn(X_train, y_train, X_val, y_val):
    """Train using TabPFN foundation model."""
    try:
        from tabpfn import TabPFNClassifier
    except ImportError:
        logger.error("TabPFN not installed. Install with: pip install tabpfn")
        logger.error("Falling back to LightGBM...")
        return train_lightgbm(X_train, y_train, X_val, y_val)

    logger.info("Training TabPFN classifier...")

    # TabPFN handles NaN values natively
    # For large datasets, subsample training data (TabPFN works best with <10K samples)
    max_train = 10_000
    if len(X_train) > max_train:
        logger.info(f"Subsampling training data: {len(X_train)} → {max_train}")
        idx = np.random.RandomState(42).choice(len(X_train), max_train, replace=False)
        X_train_sub = X_train[idx]
        y_train_sub = y_train[idx]
    else:
        X_train_sub = X_train
        y_train_sub = y_train

    try:
        clf = TabPFNClassifier()
        clf.fit(X_train_sub, y_train_sub)
        return clf, "tabpfn"
    except Exception as e:
        logger.error(f"TabPFN training failed: {e}")
        logger.error("Falling back to LightGBM...")
        return train_lightgbm(X_train, y_train, X_val, y_val)


def train_lightgbm(X_train, y_train, X_val, y_val):
    """Train using LightGBM (fallback when TabPFN unavailable)."""
    try:
        import lightgbm as lgb
    except ImportError:
        logger.error("LightGBM not installed. Install with: pip install lightgbm")
        sys.exit(1)

    logger.info("Training LightGBM classifier...")

    # Remap labels: {-1, 0, 1} → {0, 1, 2} for LightGBM
    y_train_mapped = y_train + 1  # -1→0, 0→1, 1→2
    y_val_mapped = y_val + 1
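    # LightGBM's multiclass objective expects integer class labels in
    # [0, num_class), hence the +1 shift above.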

    # Compute class weights to handle imbalanced labels
    class_counts = Counter(y_train_mapped)
    total = len(y_train_mapped)
    n_classes = len(class_counts)
    class_weight = {c: total / (n_classes * count) for c, count in class_counts.items()}
    sample_weights = np.array([class_weight[y] for y in y_train_mapped])
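    # These are sklearn-style "balanced" weights, w_c = total / (n_classes * n_c):
    # e.g. hypothetical counts {0: 700, 1: 200, 2: 100} with total 1000 give weights
    # of roughly {0: 0.48, 1: 1.67, 2: 3.33}, so each class contributes equally.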

    train_data = lgb.Dataset(
        X_train, label=y_train_mapped, weight=sample_weights, feature_name=FEATURE_COLUMNS
    )
    val_data = lgb.Dataset(
        X_val, label=y_val_mapped, feature_name=FEATURE_COLUMNS, reference=train_data
    )

    params = {
        "objective": "multiclass",
        "num_class": 3,
        "metric": "multi_logloss",
        # Lower LR + more rounds = smoother learning on noisy data
        "learning_rate": 0.01,
        # More capacity to find feature interactions
        "num_leaves": 63,
        "max_depth": 8,
        "min_child_samples": 100,
        # Aggressive subsampling to reduce overfitting on noise
        "subsample": 0.7,
        "subsample_freq": 1,
        "colsample_bytree": 0.7,
        # Stronger regularization for financial data
        "reg_alpha": 1.0,
        "reg_lambda": 1.0,
        "min_gain_to_split": 0.01,
        "path_smooth": 1.0,
        "verbose": -1,
        "seed": 42,
    }

    callbacks = [
        lgb.log_evaluation(period=100),
        lgb.early_stopping(stopping_rounds=100),
    ]
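    # Early stopping halts training once val multi_logloss has not improved for
    # 100 consecutive rounds; the booster then records its best_iteration.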

    booster = lgb.train(
        params,
        train_data,
        num_boost_round=2000,
        valid_sets=[val_data],
        callbacks=callbacks,
    )

    # Wrap in sklearn-compatible interface
    clf = LGBMWrapper(booster, y_train)

    return clf, "lightgbm"


def evaluate(model, X_val, y_val, model_type: str) -> dict:
    """Evaluate model and return metrics dict."""
    # Normalize input to a DataFrame with named feature columns
    if isinstance(X_val, np.ndarray):
        X_df = pd.DataFrame(X_val, columns=FEATURE_COLUMNS)
    else:
        X_df = X_val

    y_pred = model.predict(X_df)
    probas = model.predict_proba(X_df)

    accuracy = accuracy_score(y_val, y_pred)
    # Pin labels so the report and confusion matrix stay 3x3 even if a class
    # is absent from the validation window
    report = classification_report(
        y_val,
        y_pred,
        labels=[-1, 0, 1],
        target_names=["LOSS (-1)", "TIMEOUT (0)", "WIN (+1)"],
        output_dict=True,
        zero_division=0,
    )
    cm = confusion_matrix(y_val, y_pred, labels=[-1, 0, 1])

    # Column of the WIN class in predict_proba output
    win_col_idx = list(model.classes_).index(1)

    # Win-class specific metrics
    win_mask = y_val == 1
    if win_mask.sum() > 0:
        win_probs = probas[win_mask]
        avg_win_prob_for_actual_wins = float(win_probs[:, win_col_idx].mean())
    else:
        avg_win_prob_for_actual_wins = 0.0

    # High-confidence win precision
    high_conf_mask = probas[:, win_col_idx] >= 0.6
    if high_conf_mask.sum() > 0:
        high_conf_precision = float((y_val[high_conf_mask] == 1).mean())
        high_conf_count = int(high_conf_mask.sum())
    else:
        high_conf_precision = 0.0
        high_conf_count = 0
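    # high_conf_precision is the share of actual WINs among samples the model
    # scores at P(WIN) >= 0.6, i.e. the precision of a fixed-threshold WIN call.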

    # Calibration analysis: do higher P(WIN) quintiles actually win more?
    win_probs_all = probas[:, win_col_idx]
    quintile_labels = pd.qcut(win_probs_all, q=5, labels=False, duplicates="drop")
    calibration = {}
    for q in sorted(set(quintile_labels)):
        mask = quintile_labels == q
        q_probs = win_probs_all[mask]
        q_actual_win_rate = float((y_val[mask] == 1).mean())
        q_actual_loss_rate = float((y_val[mask] == -1).mean())
        calibration[f"Q{q+1}"] = {
            "mean_predicted_win_prob": round(float(q_probs.mean()), 4),
            "actual_win_rate": round(q_actual_win_rate, 4),
            "actual_loss_rate": round(q_actual_loss_rate, 4),
            "count": int(mask.sum()),
        }
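    # For a well-calibrated model, mean_predicted_win_prob tracks actual_win_rate
    # within each quintile, and actual_win_rate rises monotonically from Q1 to Q5.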

    # Top decile (top 10% by P(WIN)) — most actionable metric
    top_decile_threshold = np.percentile(win_probs_all, 90)
    top_decile_mask = win_probs_all >= top_decile_threshold
    top_decile_win_rate = (
        float((y_val[top_decile_mask] == 1).mean()) if top_decile_mask.sum() > 0 else 0.0
    )
    top_decile_loss_rate = (
        float((y_val[top_decile_mask] == -1).mean()) if top_decile_mask.sum() > 0 else 0.0
    )

    metrics = {
        "model_type": model_type,
        "accuracy": round(accuracy, 4),
        "per_class": {
            k: {kk: round(vv, 4) for kk, vv in v.items()}
            for k, v in report.items()
            if isinstance(v, dict)
        },
        "confusion_matrix": cm.tolist(),
        "avg_win_prob_for_actual_wins": round(avg_win_prob_for_actual_wins, 4),
        "high_confidence_win_precision": round(high_conf_precision, 4),
        "high_confidence_win_count": high_conf_count,
        "calibration_quintiles": calibration,
        "top_decile_win_rate": round(top_decile_win_rate, 4),
        "top_decile_loss_rate": round(top_decile_loss_rate, 4),
        "top_decile_threshold": round(float(top_decile_threshold), 4),
        "top_decile_count": int(top_decile_mask.sum()),
        "val_samples": len(y_val),
    }

    # Print summary
    logger.info(f"\n{'='*60}")
    logger.info(f"Model: {model_type}")
    logger.info(f"Overall Accuracy: {accuracy:.1%}")
    logger.info("\nPer-class metrics:")
    logger.info(f"{'':>15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
    for label, name in [(-1, "LOSS"), (0, "TIMEOUT"), (1, "WIN")]:
        key = f"{name} ({label:+d})"
        if key in report:
            r = report[key]
            logger.info(
                f"{name:>15} {r['precision']:>10.3f} {r['recall']:>10.3f} {r['f1-score']:>10.3f} {r['support']:>10.0f}"
            )

    logger.info("\nConfusion Matrix (rows=actual, cols=predicted):")
    logger.info(f"{'':>10} {'LOSS':>8} {'TIMEOUT':>8} {'WIN':>8}")
    for i, name in enumerate(["LOSS", "TIMEOUT", "WIN"]):
        logger.info(f"{name:>10} {cm[i][0]:>8} {cm[i][1]:>8} {cm[i][2]:>8}")

    logger.info("\nWin-class insights:")
    logger.info(f"  Avg P(WIN) for actual winners: {avg_win_prob_for_actual_wins:.1%}")
    logger.info(
        f"  High-confidence (>=60%) precision: {high_conf_precision:.1%} ({high_conf_count} samples)"
    )

    logger.info("\nCalibration (does higher P(WIN) = more actual wins?):")
    logger.info(
        f"{'Quintile':>10} {'Avg P(WIN)':>12} {'Actual WIN%':>12} {'Actual LOSS%':>13} {'Count':>8}"
    )
    for q_name, q_data in calibration.items():
        logger.info(
            f"{q_name:>10} {q_data['mean_predicted_win_prob']:>12.1%} "
            f"{q_data['actual_win_rate']:>12.1%} {q_data['actual_loss_rate']:>13.1%} "
            f"{q_data['count']:>8}"
        )

    logger.info("\nTop decile (top 10% by P(WIN)):")
    logger.info(f"  Threshold: P(WIN) >= {top_decile_threshold:.1%}")
    logger.info(
        f"  Actual win rate: {top_decile_win_rate:.1%} ({int(top_decile_mask.sum())} samples)"
    )
    logger.info(f"  Actual loss rate: {top_decile_loss_rate:.1%}")
    baseline_win = float((y_val == 1).mean())
    logger.info(f"  Baseline win rate: {baseline_win:.1%}")
    if baseline_win > 0:
        logger.info(f"  Lift over baseline: {top_decile_win_rate / baseline_win:.2f}x")
    logger.info(f"{'='*60}")

    return metrics


def main():
    parser = argparse.ArgumentParser(description="Train ML model for win probability")
    parser.add_argument("--dataset", type=str, default="data/ml/training_dataset.parquet")
    parser.add_argument(
        "--model",
        type=str,
        choices=["tabpfn", "lightgbm", "auto"],
        default="auto",
        help="Model type (auto tries TabPFN first, falls back to LightGBM)",
    )
    parser.add_argument(
        "--val-start",
        type=str,
        default="2024-07-01",
        help="Validation split date (default: 2024-07-01)",
    )
    parser.add_argument(
        "--max-train-samples",
        type=int,
        default=None,
        help="Limit training samples to the most recent N before val-start",
    )
    parser.add_argument("--output-dir", type=str, default="data/ml")
    args = parser.parse_args()

    if args.max_train_samples is not None and args.max_train_samples <= 0:
        logger.error("--max-train-samples must be a positive integer")
        sys.exit(1)

    # Load dataset
    df = load_dataset(args.dataset)

    # Split
    X_train, y_train, X_val, y_val = time_split(
        df,
        val_start=args.val_start,
        max_train_samples=args.max_train_samples,
    )

    if len(X_train) == 0:
        logger.error(f"No training data before {args.val_start} — adjust --val-start")
        sys.exit(1)
    if len(X_val) == 0:
        logger.error(f"No validation data after {args.val_start} — adjust --val-start")
        sys.exit(1)

    # Train
    if args.model in ("tabpfn", "auto"):
        model, model_type = train_tabpfn(X_train, y_train, X_val, y_val)
    else:
        model, model_type = train_lightgbm(X_train, y_train, X_val, y_val)

    # Evaluate
    metrics = evaluate(model, X_val, y_val, model_type)

    # Save model
    predictor = MLPredictor(model=model, feature_columns=FEATURE_COLUMNS, model_type=model_type)
    model_path = predictor.save(args.output_dir)
    logger.info(f"Model saved to {model_path}")

    # Save metrics
    metrics_path = os.path.join(args.output_dir, "metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"Metrics saved to {metrics_path}")


if __name__ == "__main__":
    main()