add sisth chapter
This commit is contained in:
parent
2fc28516a0
commit
e5ce40f468
|
|
@ -13,7 +13,6 @@
|
||||||
- [6.1 不同类型的微调](#61-不同类型的微调)
|
- [6.1 不同类型的微调](#61-不同类型的微调)
|
||||||
- [6.2 准备数据集](#62-准备数据集)
|
- [6.2 准备数据集](#62-准备数据集)
|
||||||
- [6.3 创建数据加载器](#63-创建数据加载器)
|
- [6.3 创建数据加载器](#63-创建数据加载器)
|
||||||
- [](#)
|
|
||||||
- [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型)
|
- [6.4 使用预训练权重初始化模型](#64-使用预训练权重初始化模型)
|
||||||
- [6.5 添加分类头](#65-添加分类头)
|
- [6.5 添加分类头](#65-添加分类头)
|
||||||
- [6.6 计算分类损失和准确率](#66-计算分类损失和准确率)
|
- [6.6 计算分类损失和准确率](#66-计算分类损失和准确率)
|
||||||
|
|
@ -410,7 +409,7 @@ print(f"{len(test_loader)} test batches")
|
||||||
|
|
||||||
本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。
|
本章的数据准备工作到此结束,接下来我们将初始化模型以准备进行微调。
|
||||||
|
|
||||||
##
|
|
||||||
|
|
||||||
## 6.4 使用预训练权重初始化模型
|
## 6.4 使用预训练权重初始化模型
|
||||||
|
|
||||||
|
|
@ -748,12 +747,12 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||||
if i < num_batches:
|
if i < num_batches:
|
||||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_batch)[:, -1, :] #A
|
logits = model(input_batch)[:, -1, :] #A
|
||||||
predicted_labels = torch.argmax(logits, dim=-1)
|
predicted_labels = torch.argmax(logits, dim=-1)
|
||||||
|
|
||||||
num_examples += predicted_labels.shape[0]
|
num_examples += predicted_labels.shape[0]
|
||||||
correct_predictions += (predicted_labels == target_batch).sum().item()
|
correct_predictions += (predicted_labels == target_batch).sum().item()
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return correct_predictions / num_examples
|
return correct_predictions / num_examples
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue