diff --git a/cn-Book/6.用于分类任务的微调.md b/cn-Book/6.用于分类任务的微调.md index c024a38..3f7b93b 100644 --- a/cn-Book/6.用于分类任务的微调.md +++ b/cn-Book/6.用于分类任务的微调.md @@ -13,7 +13,6 @@ - [6.1 不同类型的微调](#61-不同类型的微调) - [6.2 准备数据集](#62-准备数据集) - [6.3 创建数据加载器](#63-创建数据加载器) -- [](#) - [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型) - [6.5 添加分类头](#65-添加分类头) - [6.6 计算分类损失和准确率](#66-计算分类损失和准确率) @@ -410,7 +409,7 @@ print(f"{len(test_loader)} test batches") 本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。 -## + ## 6.4 使用预训练权重初始化模型 @@ -748,12 +747,12 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) - with torch.no_grad(): - logits = model(input_batch)[:, -1, :] #A - predicted_labels = torch.argmax(logits, dim=-1) + with torch.no_grad(): + logits = model(input_batch)[:, -1, :] #A + predicted_labels = torch.argmax(logits, dim=-1) - num_examples += predicted_labels.shape[0] - correct_predictions += (predicted_labels == target_batch).sum().item() + num_examples += predicted_labels.shape[0] + correct_predictions += (predicted_labels == target_batch).sum().item() else: break return correct_predictions / num_examples