TradingAgents/Finllama/DAPT_evaluate.py

498 lines
18 KiB
Python

#!/usr/bin/env python3
"""
DAPT Model Evaluation Script
Evaluates a Domain-Adaptive Pretrained (DAPT) Llama 3.1 model against the baseline
Llama 3.1 model on a dataset of stock earnings call transcripts.
Computes perplexity scores to measure model performance on domain-specific data.
"""
import os
import sys
import time
import argparse
from typing import List, Optional
import numpy as np
import torch
from datasets import load_dataset
from peft import PeftModel
from tqdm import tqdm
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
class DAPTEvaluator:
    """Evaluator for DAPT model vs baseline perplexity comparison.

    ``run_evaluation()`` drives the whole pipeline:

    1. Load the tokenizer and a random sample of evaluation texts.
    2. Load the baseline model and the DAPT model (the same base model with
       LoRA adapters applied).
    3. Compute the perplexity of each model on the sampled texts.
    4. Print a comparison.
    """

    def __init__(
        self,
        model_id: str = "meta-llama/Llama-3.1-8B",
        dapt_model_path: str = "/u/v/d/vdhanuka/llama3_8b_dapt_transcripts_lora",
        dataset_path: str = "/u/v/d/vdhanuka/defeatbeta-api-main/stock_earning_call_transcripts.parquet",
        sample_size: Optional[int] = None,
        sample_percentage: Optional[float] = None,
        max_length: int = 1024,
        use_qlora: bool = True,
        device: Optional[str] = None,
    ):
        """
        Initialize the evaluator.

        Args:
            model_id: HuggingFace model ID for base model
            dapt_model_path: Path to trained DAPT LoRA adapters
            dataset_path: Path to evaluation dataset (parquet)
            sample_size: Number of samples to evaluate (mutually exclusive with
                sample_percentage). Defaults to 1000 when neither is given.
            sample_percentage: Fraction of the dataset to evaluate (0.0-1.0,
                mutually exclusive with sample_size). If both are given,
                sample_percentage takes precedence.
            max_length: Maximum sequence length (in tokens) for evaluation
            use_qlora: Whether to use 4-bit QLoRA quantization (CUDA only)
            device: Device evaluation inputs are moved to (auto-detected if None)
        """
        self.model_id = model_id
        self.dapt_model_path = dapt_model_path
        self.dataset_path = dataset_path
        self.sample_size = sample_size
        self.sample_percentage = sample_percentage
        self.max_length = max_length
        self.use_qlora = use_qlora
        # Auto-detect device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        # Hugging Face token from environment (first non-empty variable wins)
        self.hf_token = (
            os.getenv("HUGGING_FACE_HUB_TOKEN")
            or os.getenv("HF_TOKEN")
            or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        )
        # Models and tokenizer are loaded lazily by the setup/load methods
        self.tokenizer = None
        self.baseline_model = None
        self.dapt_model = None
        self.eval_texts: Optional[List[str]] = None
        # Set by load_dapt_model() so a failed load is not silently retried
        # (each attempt loads a full copy of the base model).
        self._dapt_load_attempted = False
        print("🚀 Initializing DAPT Evaluator")
        print(f" Model: {model_id}")
        print(f" DAPT Path: {dapt_model_path}")
        print(f" Dataset: {dataset_path}")
        print(f" Device: {self.device}")
        if self.hf_token:
            print(" HF token: detected in environment")
        else:
            print(" HF token: not found (anonymous access)")
        if sample_percentage is not None:
            print(f" Sample Percentage: {sample_percentage*100:.1f}%")
        else:
            print(f" Sample Size: {sample_size}")
        print(f" Use QLoRA: {use_qlora}")

    def setup_tokenizer(self):
        """Load and configure the tokenizer, falling back to EOS as pad token."""
        print("\n🔧 Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_id,
            use_fast=True,
            token=self.hf_token,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        print(f" Vocab size: {self.tokenizer.vocab_size}")
        return self.tokenizer

    @staticmethod
    def _flatten_segments(example):
        """Join a row's ``transcripts`` segments into one ``text`` field.

        Each segment becomes ``"speaker: content"``, or just ``content`` when
        there is no speaker. Empty segments and segments without content are
        skipped. This is the same flattening logic used for training.
        """
        lines = []
        for seg in example.get("transcripts") or []:
            if not seg:
                continue
            speaker = seg.get("speaker")
            content = seg.get("content")
            if content is None:
                continue
            if speaker and len(str(speaker)) > 0:
                lines.append(f"{speaker}: {content}")
            else:
                lines.append(str(content))
        example["text"] = "\n".join(lines)
        return example

    def _prepare_text_column(self, ds):
        """Return ``(dataset, text_column_name)``, flattening transcripts if present.

        Without a ``transcripts`` column, the first matching preferred column
        name is used, falling back to the first column.
        """
        if "transcripts" in ds.column_names:
            print(" Flattening transcript data...")
            ds = ds.map(self._flatten_segments, desc="Flattening transcripts")
            return ds, "text"
        preferred = ["text", "transcript", "content", "body", "cleaned_text"]
        text_column = next(
            (name for name in preferred if name in ds.column_names),
            ds.column_names[0],
        )
        return ds, text_column

    def _resolve_sample_size(self, total_samples: int) -> int:
        """Number of rows to sample from a dataset of ``total_samples`` rows.

        Uses ``sample_percentage`` if set, else ``sample_size``, else 1000.
        The result is clamped to ``[1, total_samples]`` so that sampling
        without replacement cannot request more rows than exist. It is 0 only
        when ``total_samples`` is 0.
        """
        if self.sample_percentage is not None:
            requested = int(total_samples * self.sample_percentage)
        elif self.sample_size is not None:
            requested = self.sample_size
        else:
            requested = 1000
        return min(max(1, requested), total_samples)

    def load_dataset(self):
        """Load, flatten and randomly sample the evaluation dataset.

        Sampling uses the global NumPy RNG. Texts of 50 characters or fewer
        (after stripping) are dropped.

        Returns:
            The sampled texts, also stored in ``self.eval_texts``.

        Raises:
            ValueError: If the dataset has no rows.
            Exception: Any loading error is printed and re-raised.
        """
        print("\n📊 Loading evaluation dataset...")
        try:
            ds = load_dataset("parquet", data_files={"eval": self.dataset_path})["eval"]
            print(f" Dataset loaded: {len(ds)} examples")
            print(f" Columns: {ds.column_names}")
            ds, text_column = self._prepare_text_column(ds)
            print(f" Using text column: {text_column}")

            total_samples = len(ds)
            if total_samples == 0:
                raise ValueError(f"Dataset is empty: {self.dataset_path}")
            sample_size = self._resolve_sample_size(total_samples)
            if self.sample_percentage is not None:
                print(f" Using {self.sample_percentage*100:.1f}% of dataset = {sample_size} samples")

            # Get random sample for more representative evaluation
            indices = np.random.choice(total_samples, sample_size, replace=False)
            sample_ds = ds.select(indices)

            # Filter out empty or very short texts
            def is_valid_text(example):
                text = example.get(text_column, "")
                return text is not None and len(str(text).strip()) > 50

            sample_ds = sample_ds.filter(is_valid_text)
            # str() guards against a non-string auto-detected column; the
            # tokenizer and the length statistics below both need strings.
            self.eval_texts = [str(ex[text_column]) for ex in sample_ds]
            print(f" Sampled {len(self.eval_texts)} valid texts for evaluation")
            avg_chars = float(np.mean([len(t) for t in self.eval_texts])) if self.eval_texts else 0.0
            print(f" Average text length: {avg_chars:.0f} characters")
            return self.eval_texts
        except Exception as e:
            print(f"❌ Error loading dataset: {e}")
            raise

    def setup_quantization(self):
        """Return a 4-bit NF4 BitsAndBytes config, or None when QLoRA is off or CUDA is absent."""
        if not self.use_qlora or not torch.cuda.is_available():
            return None
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                bnb_4bit_use_double_quant=True,
            )
            print(" Using 4-bit quantization (QLoRA)")
            return bnb_config
        except Exception:
            print(" BitsAndBytes not available, using standard precision")
            return None

    @staticmethod
    def _default_torch_dtype() -> torch.dtype:
        """Pick bf16 on bf16-capable CUDA, fp16 on other CUDA, fp32 on CPU."""
        if not torch.cuda.is_available():
            return torch.float32
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    def _load_base_model(self):
        """Load a fresh copy of the base causal LM, quantized when configured."""
        bnb_config = self.setup_quantization()
        return AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto" if torch.cuda.is_available() else None,
            torch_dtype=self._default_torch_dtype(),
            quantization_config=bnb_config,
            low_cpu_mem_usage=True,
            token=self.hf_token,
        )

    def load_baseline_model(self):
        """Load the baseline Llama 3.1 model into ``self.baseline_model``."""
        print("\n🏗️ Loading baseline model...")
        self.baseline_model = self._load_base_model()
        self.baseline_model.eval()
        print(" Baseline model loaded successfully")
        return self.baseline_model

    def load_dapt_model(self):
        """Load the DAPT model (base model + LoRA adapters).

        Returns:
            The PEFT model, also stored in ``self.dapt_model``. Returns None
            (after printing the reason) if the adapter path is missing or
            loading fails.
        """
        print("\n🎯 Loading DAPT model...")
        self._dapt_load_attempted = True
        if not os.path.exists(self.dapt_model_path):
            print(f"❌ DAPT model path not found: {self.dapt_model_path}")
            return None
        try:
            # Load a separate base copy: PEFT injects the LoRA layers into the
            # model it wraps, so reusing self.baseline_model would alter it.
            dapt_base_model = self._load_base_model()
            self.dapt_model = PeftModel.from_pretrained(dapt_base_model, self.dapt_model_path)
            self.dapt_model.eval()
            print(" DAPT model loaded successfully")
            return self.dapt_model
        except Exception as e:
            print(f"❌ Error loading DAPT model: {e}")
            return None

    @staticmethod
    def _sequence_nll(model, input_ids):
        """Return ``(summed NLL, predicted-token count)`` for one sequence.

        ``input_ids`` is expected to have shape ``(1, seq_len)`` with
        ``seq_len >= 2``. Hugging Face causal-LM heads shift ``labels``
        internally, so ``outputs.loss`` is the mean over ``seq_len - 1``
        next-token predictions. Multiplying by that count recovers the sum.
        """
        labels = input_ids.clone()
        outputs = model(input_ids=input_ids, labels=labels)
        n_predicted = input_ids.size(1) - 1
        return outputs.loss.item() * n_predicted, n_predicted

    def compute_perplexity(self, model, texts: List[str]) -> float:
        """
        Compute corpus-level perplexity of ``model`` over ``texts``.

        Each text is tokenized independently (truncated to ``self.max_length``
        tokens) and scored in its own forward pass. The result is
        ``exp(total NLL / total predicted tokens)``.

        Args:
            model: The causal language model to evaluate
            texts: List of text strings

        Returns:
            Perplexity score, or ``inf`` if no text yielded at least two tokens.
        """
        model.eval()
        total_nll = 0.0
        total_tokens = 0
        with torch.no_grad():
            for text in tqdm(texts, desc="Computing perplexity", unit="text"):
                encodings = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_length,
                    padding=False,
                )
                # NOTE(review): inputs go to self.device, while the models are
                # placed by device_map="auto". An explicit device that differs
                # from the model's placement may fail in the forward pass.
                input_ids = encodings.input_ids.to(self.device)
                if input_ids.size(1) <= 1:
                    continue  # need at least one (context, next-token) pair
                nll, n_predicted = self._sequence_nll(model, input_ids)
                total_nll += nll
                total_tokens += n_predicted
        if total_tokens == 0:
            return float("inf")
        return float(np.exp(total_nll / total_tokens))

    def evaluate_models(self):
        """Evaluate the baseline and DAPT models on ``self.eval_texts``.

        Models not yet loaded are loaded on demand. A DAPT load that already
        failed is not retried.

        Returns:
            ``{"baseline": {...}, "dapt": {...} or None}``, where each entry
            holds ``perplexity`` and ``eval_time`` (seconds).

        Raises:
            ValueError: If ``load_dataset()`` has not been called.
        """
        if self.eval_texts is None:
            raise ValueError("Evaluation texts not loaded. Call load_dataset() first.")
        results = {}
        # Evaluate baseline model
        if self.baseline_model is None:
            self.load_baseline_model()
        print("\n📈 Evaluating BASELINE model...")
        start_time = time.time()
        baseline_ppl = self.compute_perplexity(self.baseline_model, self.eval_texts)
        baseline_time = time.time() - start_time
        results["baseline"] = {
            "perplexity": baseline_ppl,
            "eval_time": baseline_time,
        }
        print(f" Perplexity: {baseline_ppl:.4f}")
        print(f" Evaluation time: {baseline_time:.2f} seconds")
        # Evaluate DAPT model (getattr: tolerate instances built without __init__)
        if self.dapt_model is None and not getattr(self, "_dapt_load_attempted", False):
            self.dapt_model = self.load_dapt_model()
        if self.dapt_model is not None:
            print("\n📈 Evaluating DAPT model...")
            start_time = time.time()
            dapt_ppl = self.compute_perplexity(self.dapt_model, self.eval_texts)
            dapt_time = time.time() - start_time
            results["dapt"] = {
                "perplexity": dapt_ppl,
                "eval_time": dapt_time,
            }
            print(f" Perplexity: {dapt_ppl:.4f}")
            print(f" Evaluation time: {dapt_time:.2f} seconds")
        else:
            print("\n⚠️ DAPT model not available for evaluation")
            results["dapt"] = None
        return results

    def print_results(self, results):
        """Print a formatted report of ``results`` as returned by ``evaluate_models()``."""
        print("\n" + "=" * 70)
        print("🎯 EVALUATION RESULTS")
        print("=" * 70)
        print(f"Dataset: {self.dataset_path}")
        print(f"Samples evaluated: {len(self.eval_texts or [])}")
        print(f"Max sequence length: {self.max_length}")
        print()
        baseline = results.get("baseline")
        dapt = results.get("dapt")
        if baseline:
            print("BASELINE LLAMA 3.1:")
            print(f" Perplexity: {baseline['perplexity']:.4f}")
            print(f" Evaluation time: {baseline['eval_time']:.2f} seconds")
        if dapt:
            dapt_ppl = dapt["perplexity"]
            print("\nDAPT MODEL:")
            print(f" Perplexity: {dapt_ppl:.4f}")
            print(f" Evaluation time: {dapt['eval_time']:.2f} seconds")
            # Comparison
            if baseline:
                baseline_ppl = baseline["perplexity"]
                print("\nCOMPARISON:")
                print(f" Baseline PPL: {baseline_ppl:.4f}")
                print(f" DAPT PPL: {dapt_ppl:.4f}")
                # compute_perplexity returns inf when nothing was scored;
                # a relative change is meaningless in that case.
                if np.isfinite(baseline_ppl) and np.isfinite(dapt_ppl):
                    improvement = ((baseline_ppl - dapt_ppl) / baseline_ppl) * 100.0
                    print(f" Improvement: {improvement:.2f}%")
                else:
                    print(" Improvement: n/a (non-finite perplexity)")
                if dapt_ppl < baseline_ppl:
                    print("✅ SUCCESS: DAPT model outperforms baseline!")
                    print(" The domain-adaptive pretraining improved performance on earnings call data.")
                else:
                    print("⚠️ NOTE: DAPT model does not outperform baseline")
                    print(" Consider adjusting training parameters or dataset.")
        else:
            print("\n❌ DAPT model evaluation failed or not available")
        print("\n" + "-" * 70)
        print("📝 INTERPRETATION")
        print("-" * 70)
        print("Perplexity measures how well the model predicts the next token in sequences.")
        print("Lower perplexity = better predictive performance on the domain.")
        print("Typical perplexity ranges: 10-100+ (lower is better)")
        print()
        print("Earnings call transcripts contain specialized financial language,")
        print("so domain adaptation should ideally reduce perplexity compared to baseline.")

    def run_evaluation(self):
        """Run the complete evaluation pipeline.

        Returns:
            The results dict from ``evaluate_models()``, or None if any step
            raised. The traceback is printed in that case.
        """
        try:
            # Setup
            self.setup_tokenizer()
            self.load_dataset()
            self.load_baseline_model()
            self.load_dapt_model()
            # Evaluate
            results = self.evaluate_models()
            # Display results
            self.print_results(results)
            return results
        except Exception as e:
            print(f"❌ Evaluation failed: {e}")
            import traceback
            traceback.print_exc()
            return None
def main():
    """Parse command-line arguments, run the evaluation and return its results.

    Returns:
        The results dict from ``DAPTEvaluator.run_evaluation()``, or None if
        the evaluation failed.
    """
    parser = argparse.ArgumentParser(description="Evaluate DAPT model vs baseline perplexity")
    parser.add_argument("--model-id", default="meta-llama/Llama-3.1-8B", help="Base model ID")
    parser.add_argument(
        "--dapt-path",
        default="/u/v/d/vdhanuka/llama3_8b_dapt_transcripts_lora",
        help="Path to DAPT LoRA adapters",
    )
    parser.add_argument(
        "--dataset",
        default="/u/v/d/vdhanuka/defeatbeta-api-main/stock_earning_call_transcripts.parquet",
        help="Path to evaluation dataset",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=None,
        help="Number of samples to evaluate (mutually exclusive with --sample-percentage)",
    )
    parser.add_argument(
        "--sample-percentage",
        type=float,
        default=None,
        help="Percentage of dataset to evaluate (0.0-1.0, mutually exclusive with --sample-size)",
    )
    parser.add_argument("--max-length", type=int, default=1024, help="Maximum sequence length")
    parser.add_argument("--no-qlora", action="store_true", help="Disable QLoRA quantization")
    parser.add_argument("--device", default=None, help="Device to use (cuda/cpu)")
    args = parser.parse_args()
    # Validate mutually exclusive arguments
    if args.sample_size is not None and args.sample_percentage is not None:
        parser.error("--sample-size and --sample-percentage are mutually exclusive")
    # The value is a fraction, not a percent: reject e.g. 50 (meant as 50%) up
    # front instead of failing later when sampling more rows than exist.
    if args.sample_percentage is not None and not 0.0 < args.sample_percentage <= 1.0:
        parser.error("--sample-percentage must be a fraction in (0.0, 1.0], e.g. 0.1 for 10%")
    if args.sample_size is None and args.sample_percentage is None:
        args.sample_size = 1000  # Default to 1000 samples
    # Create evaluator
    evaluator = DAPTEvaluator(
        model_id=args.model_id,
        dapt_model_path=args.dapt_path,
        dataset_path=args.dataset,
        sample_size=args.sample_size,
        sample_percentage=args.sample_percentage,
        max_length=args.max_length,
        use_qlora=not args.no_qlora,
        device=args.device,
    )
    # Run evaluation
    results = evaluator.run_evaluation()
    return results
if __name__ == "__main__":
    # Set random seed for reproducible sampling
    np.random.seed(42)
    torch.manual_seed(42)
    # main() returns None when the evaluation pipeline failed; surface that
    # as a non-zero exit status so batch jobs and schedulers can detect it.
    sys.exit(0 if main() is not None else 1)