#!/usr/bin/env python3
"""Train ML model on the generated dataset.

Supports TabPFN (recommended, requires GPU or API) and LightGBM (fallback).
Uses time-based train/validation split to prevent data leakage.

Usage:
    python scripts/train_ml_model.py
    python scripts/train_ml_model.py --model lightgbm
    python scripts/train_ml_model.py --model tabpfn --dataset data/ml/training_dataset.parquet
    python scripts/train_ml_model.py --max-train-samples 5000
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
)

# Add project root to path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from tradingagents.ml.feature_engineering import FEATURE_COLUMNS
from tradingagents.ml.predictor import LGBMWrapper, MLPredictor
from tradingagents.utils.logger import get_logger

logger = get_logger(__name__)

DATA_DIR = Path("data/ml")
LABEL_NAMES = {-1: "LOSS", 0: "TIMEOUT", 1: "WIN"}


def load_dataset(path: str) -> pd.DataFrame:
    """Load and validate the training dataset."""
    df = pd.read_parquet(path)
    logger.info(f"Loaded {len(df)} samples from {path}")

    # Validate columns
    missing = [c for c in FEATURE_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing feature columns: {missing}")
    if "label" not in df.columns:
        raise ValueError("Missing 'label' column")
    if "date" not in df.columns:
        raise ValueError("Missing 'date' column")

    # Show label distribution
    for label, name in LABEL_NAMES.items():
        count = (df["label"] == label).sum()
        pct = count / len(df) * 100
        logger.info(f"  {name:>7} ({label:+d}): {count:>7} ({pct:.1f}%)")

    return df


def time_split(
    df: pd.DataFrame,
    val_start: str = "2024-07-01",
    max_train_samples: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split dataset by time — train on older data, validate on newer."""
    df["date"] = pd.to_datetime(df["date"])
    val_start_dt = pd.Timestamp(val_start)

    train = df[df["date"] < val_start_dt].copy()
    val = df[df["date"] >= val_start_dt].copy()

    if max_train_samples is not None and len(train) > max_train_samples:
        train = train.sort_values("date").tail(max_train_samples)
        logger.info(
            f"Limiting training samples to most recent {max_train_samples} "
            f"before {val_start}"
        )

    logger.info(f"Time-based split at {val_start}:")
    # Guard the range logging: .min()/.max() on an empty datetime column
    # returns NaT, and NaT.date() raises.
    if len(train) > 0:
        logger.info(
            f"  Train: {len(train)} samples "
            f"({train['date'].min().date()} to {train['date'].max().date()})"
        )
    if len(val) > 0:
        logger.info(
            f"  Val:   {len(val)} samples "
            f"({val['date'].min().date()} to {val['date'].max().date()})"
        )
    else:
        logger.warning(f"  Val:   0 samples (no data on or after {val_start})")

    X_train = train[FEATURE_COLUMNS].values
    y_train = train["label"].values.astype(int)
    X_val = val[FEATURE_COLUMNS].values
    y_val = val["label"].values.astype(int)
    return X_train, y_train, X_val, y_val
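# Illustration of the split boundary above (a sketch, not executed): with the
# default val_start="2024-07-01", a row dated 2024-06-30 lands in train
# (date < cutoff) and a row dated 2024-07-01 lands in validation
# (date >= cutoff). For example:
#
#   df = pd.DataFrame({
#       "date": ["2024-06-30", "2024-07-01"],
#       "label": [1, -1],
#       # ... plus one column per name in FEATURE_COLUMNS ...
#   })
#   X_train, y_train, X_val, y_val = time_split(df)  # 1 train row, 1 val row
#
# Because the cutoff is strictly chronological, every validation row is dated
# later than every training row.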
def train_tabpfn(X_train, y_train, X_val, y_val):
    """Train using TabPFN foundation model."""
    try:
        from tabpfn import TabPFNClassifier
    except ImportError:
        logger.error("TabPFN not installed. Install with: pip install tabpfn")
        logger.error("Falling back to LightGBM...")
        return train_lightgbm(X_train, y_train, X_val, y_val)

    logger.info("Training TabPFN classifier...")

    # TabPFN handles NaN values natively.
    # For large datasets, subsample training data (TabPFN works best with
    # <10K samples).
    max_train = 10_000
    if len(X_train) > max_train:
        logger.info(f"Subsampling training data: {len(X_train)} → {max_train}")
        idx = np.random.RandomState(42).choice(len(X_train), max_train, replace=False)
        X_train_sub = X_train[idx]
        y_train_sub = y_train[idx]
    else:
        X_train_sub = X_train
        y_train_sub = y_train

    try:
        clf = TabPFNClassifier()
        clf.fit(X_train_sub, y_train_sub)
        return clf, "tabpfn"
    except Exception as e:
        logger.error(f"TabPFN training failed: {e}")
        logger.error("Falling back to LightGBM...")
        return train_lightgbm(X_train, y_train, X_val, y_val)


def train_lightgbm(X_train, y_train, X_val, y_val):
    """Train using LightGBM (fallback when TabPFN unavailable)."""
    try:
        import lightgbm as lgb
    except ImportError:
        logger.error("LightGBM not installed. Install with: pip install lightgbm")
        sys.exit(1)

    logger.info("Training LightGBM classifier...")

    # Remap labels: {-1, 0, 1} → {0, 1, 2} for LightGBM
    y_train_mapped = y_train + 1  # -1→0, 0→1, 1→2
    y_val_mapped = y_val + 1

    # Compute class weights to handle imbalanced labels
    from collections import Counter

    class_counts = Counter(y_train_mapped)
    total = len(y_train_mapped)
    n_classes = len(class_counts)
    class_weight = {c: total / (n_classes * count) for c, count in class_counts.items()}
    sample_weights = np.array([class_weight[y] for y in y_train_mapped])

    train_data = lgb.Dataset(
        X_train, label=y_train_mapped, weight=sample_weights, feature_name=FEATURE_COLUMNS
    )
    val_data = lgb.Dataset(
        X_val, label=y_val_mapped, feature_name=FEATURE_COLUMNS, reference=train_data
    )

    params = {
        "objective": "multiclass",
        "num_class": 3,
        "metric": "multi_logloss",
        # Lower LR + more rounds = smoother learning on noisy data
        "learning_rate": 0.01,
        # More capacity to find feature interactions
        "num_leaves": 63,
        "max_depth": 8,
        "min_child_samples": 100,
        # Aggressive subsampling to reduce overfitting on noise
        "subsample": 0.7,
        "subsample_freq": 1,
        "colsample_bytree": 0.7,
        # Stronger regularization for financial data
        "reg_alpha": 1.0,
        "reg_lambda": 1.0,
        "min_gain_to_split": 0.01,
        "path_smooth": 1.0,
        "verbose": -1,
        "seed": 42,
    }

    callbacks = [
        lgb.log_evaluation(period=100),
        lgb.early_stopping(stopping_rounds=100),
    ]

    booster = lgb.train(
        params,
        train_data,
        num_boost_round=2000,
        valid_sets=[val_data],
        callbacks=callbacks,
    )

    # Wrap in sklearn-compatible interface
    clf = LGBMWrapper(booster, y_train)
    return clf, "lightgbm"
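# Worked example of the inverse-frequency class weighting in train_lightgbm
# above (illustrative numbers, not from the real dataset): with 1,000 training
# rows split as {LOSS: 300, TIMEOUT: 500, WIN: 200} after remapping,
#
#   w(LOSS)    = 1000 / (3 * 300) ≈ 1.11
#   w(TIMEOUT) = 1000 / (3 * 500) ≈ 0.67
#   w(WIN)     = 1000 / (3 * 200) ≈ 1.67
#
# so each class contributes roughly equal total weight (≈333) to the loss.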
def evaluate(model, X_val, y_val, model_type: str) -> dict:
    """Evaluate model and return metrics dict."""
    if isinstance(X_val, np.ndarray):
        X_df = pd.DataFrame(X_val, columns=FEATURE_COLUMNS)
    else:
        X_df = X_val

    y_pred = model.predict(X_df)
    probas = model.predict_proba(X_df)

    accuracy = accuracy_score(y_val, y_pred)
    # Pass labels explicitly so the report and matrix keep a fixed shape even
    # when a class is absent from the validation period.
    report = classification_report(
        y_val,
        y_pred,
        labels=[-1, 0, 1],
        target_names=["LOSS (-1)", "TIMEOUT (0)", "WIN (+1)"],
        output_dict=True,
        zero_division=0,
    )
    cm = confusion_matrix(y_val, y_pred, labels=[-1, 0, 1])

    # Column of P(WIN) in the probability matrix (computed once, used below)
    win_col_idx = list(model.classes_).index(1)

    # Win-class specific metrics
    win_mask = y_val == 1
    if win_mask.sum() > 0:
        win_probs = probas[win_mask]
        avg_win_prob_for_actual_wins = float(win_probs[:, win_col_idx].mean())
    else:
        avg_win_prob_for_actual_wins = 0.0

    # High-confidence win precision
    high_conf_mask = probas[:, win_col_idx] >= 0.6
    if high_conf_mask.sum() > 0:
        high_conf_precision = float((y_val[high_conf_mask] == 1).mean())
        high_conf_count = int(high_conf_mask.sum())
    else:
        high_conf_precision = 0.0
        high_conf_count = 0

    # Calibration analysis: do higher P(WIN) quintiles actually win more?
    win_probs_all = probas[:, win_col_idx]
    quintile_labels = pd.qcut(win_probs_all, q=5, labels=False, duplicates="drop")
    calibration = {}
    for q in sorted(set(quintile_labels)):
        mask = quintile_labels == q
        q_probs = win_probs_all[mask]
        q_actual_win_rate = float((y_val[mask] == 1).mean())
        q_actual_loss_rate = float((y_val[mask] == -1).mean())
        calibration[f"Q{q + 1}"] = {
            "mean_predicted_win_prob": round(float(q_probs.mean()), 4),
            "actual_win_rate": round(q_actual_win_rate, 4),
            "actual_loss_rate": round(q_actual_loss_rate, 4),
            "count": int(mask.sum()),
        }

    # Top decile (top 10% by P(WIN)) — most actionable metric
    top_decile_threshold = np.percentile(win_probs_all, 90)
    top_decile_mask = win_probs_all >= top_decile_threshold
    top_decile_win_rate = (
        float((y_val[top_decile_mask] == 1).mean()) if top_decile_mask.sum() > 0 else 0.0
    )
    top_decile_loss_rate = (
        float((y_val[top_decile_mask] == -1).mean()) if top_decile_mask.sum() > 0 else 0.0
    )

    metrics = {
        "model_type": model_type,
        "accuracy": round(accuracy, 4),
        "per_class": {
            k: {kk: round(vv, 4) for kk, vv in v.items()}
            for k, v in report.items()
            if isinstance(v, dict)
        },
        "confusion_matrix": cm.tolist(),
        "avg_win_prob_for_actual_wins": round(avg_win_prob_for_actual_wins, 4),
        "high_confidence_win_precision": round(high_conf_precision, 4),
        "high_confidence_win_count": high_conf_count,
        "calibration_quintiles": calibration,
        "top_decile_win_rate": round(top_decile_win_rate, 4),
        "top_decile_loss_rate": round(top_decile_loss_rate, 4),
        "top_decile_threshold": round(float(top_decile_threshold), 4),
        "top_decile_count": int(top_decile_mask.sum()),
        "val_samples": len(y_val),
    }

    # Print summary
    logger.info(f"\n{'='*60}")
    logger.info(f"Model: {model_type}")
    logger.info(f"Overall Accuracy: {accuracy:.1%}")

    logger.info("\nPer-class metrics:")
    logger.info(f"{'':>15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
    for label, name in [(-1, "LOSS"), (0, "TIMEOUT"), (1, "WIN")]:
        key = f"{name} ({label:+d})"
        if key in report:
            r = report[key]
            logger.info(
                f"{name:>15} {r['precision']:>10.3f} {r['recall']:>10.3f} "
                f"{r['f1-score']:>10.3f} {r['support']:>10.0f}"
            )

    logger.info("\nConfusion Matrix (rows=actual, cols=predicted):")
    logger.info(f"{'':>10} {'LOSS':>8} {'TIMEOUT':>8} {'WIN':>8}")
    for i, name in enumerate(["LOSS", "TIMEOUT", "WIN"]):
        logger.info(f"{name:>10} {cm[i][0]:>8} {cm[i][1]:>8} {cm[i][2]:>8}")

    logger.info("\nWin-class insights:")
    logger.info(f"  Avg P(WIN) for actual winners: {avg_win_prob_for_actual_wins:.1%}")
    logger.info(
        f"  High-confidence (>=60%) precision: {high_conf_precision:.1%} "
        f"({high_conf_count} samples)"
    )

    logger.info("\nCalibration (does higher P(WIN) = more actual wins?):")
    logger.info(
        f"{'Quintile':>10} {'Avg P(WIN)':>12} {'Actual WIN%':>12} {'Actual LOSS%':>13} {'Count':>8}"
    )
    for q_name, q_data in calibration.items():
        logger.info(
            f"{q_name:>10} {q_data['mean_predicted_win_prob']:>12.1%} "
            f"{q_data['actual_win_rate']:>12.1%} {q_data['actual_loss_rate']:>13.1%} "
            f"{q_data['count']:>8}"
        )

    logger.info("\nTop decile (top 10% by P(WIN)):")
    logger.info(f"  Threshold: P(WIN) >= {top_decile_threshold:.1%}")
    logger.info(
        f"  Actual win rate:  {top_decile_win_rate:.1%} ({int(top_decile_mask.sum())} samples)"
    )
    logger.info(f"  Actual loss rate: {top_decile_loss_rate:.1%}")

    baseline_win = float((y_val == 1).mean())
    logger.info(f"  Baseline win rate: {baseline_win:.1%}")
    if baseline_win > 0:
        logger.info(f"  Lift over baseline: {top_decile_win_rate / baseline_win:.2f}x")
    logger.info(f"{'='*60}")

    return metrics
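# Note on the quintile loop in evaluate(): pd.qcut(..., labels=False,
# duplicates="drop") assigns each sample an integer bin by rank of P(WIN).
# If the model emits many tied probabilities, "drop" merges duplicate bin
# edges and fewer than five bins come back, which is why the loop iterates
# over sorted(set(quintile_labels)) instead of assuming bins 0..4.
# Illustrative call (not executed here):
#
#   pd.qcut([0.1, 0.1, 0.1, 0.1, 0.9], q=5, labels=False, duplicates="drop")
#   # -> array([0, 0, 0, 0, 1]); only two distinct bins survive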
logger.info(f" Baseline win rate: {baseline_win:.1%}") if baseline_win > 0: logger.info(f" Lift over baseline: {top_decile_win_rate / baseline_win:.2f}x") logger.info(f"{'='*60}") return metrics def main(): parser = argparse.ArgumentParser(description="Train ML model for win probability") parser.add_argument("--dataset", type=str, default="data/ml/training_dataset.parquet") parser.add_argument( "--model", type=str, choices=["tabpfn", "lightgbm", "auto"], default="auto", help="Model type (auto tries TabPFN first, falls back to LightGBM)", ) parser.add_argument( "--val-start", type=str, default="2024-07-01", help="Validation split date (default: 2024-07-01)", ) parser.add_argument( "--max-train-samples", type=int, default=None, help="Limit training samples to the most recent N before val-start", ) parser.add_argument("--output-dir", type=str, default="data/ml") args = parser.parse_args() if args.max_train_samples is not None and args.max_train_samples <= 0: logger.error("--max-train-samples must be a positive integer") sys.exit(1) # Load dataset df = load_dataset(args.dataset) # Split X_train, y_train, X_val, y_val = time_split( df, val_start=args.val_start, max_train_samples=args.max_train_samples, ) if len(X_val) == 0: logger.error(f"No validation data after {args.val_start} — adjust --val-start") sys.exit(1) # Train if args.model == "tabpfn" or args.model == "auto": model, model_type = train_tabpfn(X_train, y_train, X_val, y_val) else: model, model_type = train_lightgbm(X_train, y_train, X_val, y_val) # Evaluate metrics = evaluate(model, X_val, y_val, model_type) # Save model predictor = MLPredictor(model=model, feature_columns=FEATURE_COLUMNS, model_type=model_type) model_path = predictor.save(args.output_dir) logger.info(f"Model saved to {model_path}") # Save metrics metrics_path = os.path.join(args.output_dir, "metrics.json") with open(metrics_path, "w") as f: json.dump(metrics, f, indent=2) logger.info(f"Metrics saved to {metrics_path}") if __name__ == "__main__": main()