diff --git a/cn-Book/5.在无标记数据集上进行预训练.md b/cn-Book/5.在无标记数据集上进行预训练.md index 7097f3e..079d8f0 100644 --- a/cn-Book/5.在无标记数据集上进行预训练.md +++ b/cn-Book/5.在无标记数据集上进行预训练.md @@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size, idx_next = torch.argmax(logits, dim=-1, keepdim=True) if idx_next == eos_id: #E break + idx_next = idx_next.unsqueeze(1) idx = torch.cat((idx, idx_next), dim=1) - return idx +return idx #A For循环与之前相同:获取logits,仅关注最后的时间步