From e5ce40f468dfb5de874c1600ba4aaac920e6abb3 Mon Sep 17 00:00:00 2001 From: skindhu Date: Tue, 12 Nov 2024 19:39:07 +0800 Subject: [PATCH] add sisth chapter --- cn-Book/6.用于分类任务的微调.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cn-Book/6.用于分类任务的微调.md b/cn-Book/6.用于分类任务的微调.md index c024a38..3f7b93b 100644 --- a/cn-Book/6.用于分类任务的微调.md +++ b/cn-Book/6.用于分类任务的微调.md @@ -13,7 +13,6 @@ - [6.1 不同类型的微调](#61-不同类型的微调) - [6.2 准备数据集](#62-准备数据集) - [6.3 创建数据加载器](#63-创建数据加载器) -- [](#) - [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型) - [6.5 添加分类头](#65-添加分类头) - [6.6 计算分类损失和准确率](#66-计算分类损失和准确率) @@ -410,7 +409,7 @@ print(f"{len(test_loader)} test batches") 本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。 -## + ## 6.4 使用预训练权重初始化模型 @@ -748,12 +747,12 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) - with torch.no_grad(): - logits = model(input_batch)[:, -1, :] #A - predicted_labels = torch.argmax(logits, dim=-1) + with torch.no_grad(): + logits = model(input_batch)[:, -1, :] #A + predicted_labels = torch.argmax(logits, dim=-1) - num_examples += predicted_labels.shape[0] - correct_predictions += (predicted_labels == target_batch).sum().item() + num_examples += predicted_labels.shape[0] + correct_predictions += (predicted_labels == target_batch).sum().item() else: break return correct_predictions / num_examples