add sisth chapter

This commit is contained in:
skindhu 2024-11-12 19:39:07 +08:00
parent 2fc28516a0
commit e5ce40f468
1 changed files with 6 additions and 7 deletions

View File

@ -13,7 +13,6 @@
- [6.1 不同类型的微调](#61-不同类型的微调) - [6.1 不同类型的微调](#61-不同类型的微调)
- [6.2 准备数据集](#62-准备数据集) - [6.2 准备数据集](#62-准备数据集)
- [6.3 创建数据加载器](#63-创建数据加载器) - [6.3 创建数据加载器](#63-创建数据加载器)
- [](#)
- [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型) - [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型)
- [6.5 添加分类头](#65-添加分类头) - [6.5 添加分类头](#65-添加分类头)
- [6.6 计算分类损失和准确率](#66-计算分类损失和准确率) - [6.6 计算分类损失和准确率](#66-计算分类损失和准确率)
@ -410,7 +409,7 @@ print(f"{len(test_loader)} test batches")
本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。 本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。
##
## 6.4 使用预训练权重初始化模型 ## 6.4 使用预训练权重初始化模型
@ -748,12 +747,12 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
if i < num_batches: if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device) input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad(): with torch.no_grad():
logits = model(input_batch)[:, -1, :] #A logits = model(input_batch)[:, -1, :] #A
predicted_labels = torch.argmax(logits, dim=-1) predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0] num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item() correct_predictions += (predicted_labels == target_batch).sum().item()
else: else:
break break
return correct_predictions / num_examples return correct_predictions / num_examples