add sisth chapter
This commit is contained in:
parent
2fc28516a0
commit
e5ce40f468
|
|
@ -13,7 +13,6 @@
|
|||
- [6.1 不同类型的微调](#61-不同类型的微调)
|
||||
- [6.2 准备数据集](#62-准备数据集)
|
||||
- [6.3 创建数据加载器](#63-创建数据加载器)
|
||||
- [](#)
|
||||
- [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型)
|
||||
- [6.5 添加分类头](#65-添加分类头)
|
||||
- [6.6 计算分类损失和准确率](#66-计算分类损失和准确率)
|
||||
|
|
@ -410,7 +409,7 @@ print(f"{len(test_loader)} test batches")
|
|||
|
||||
本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。
|
||||
|
||||
##
|
||||
|
||||
|
||||
## 6.4 使用预训练权重初始化模型
|
||||
|
||||
|
|
@ -748,12 +747,12 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
|||
if i < num_batches:
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_batch)[:, -1, :] #A
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
with torch.no_grad():
|
||||
logits = model(input_batch)[:, -1, :] #A
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
num_examples += predicted_labels.shape[0]
|
||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||
num_examples += predicted_labels.shape[0]
|
||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||
else:
|
||||
break
|
||||
return correct_predictions / num_examples
|
||||
|
|
|
|||
Loading…
Reference in New Issue