From f399a51b7937c6cefe9b81ef1f4c9538015e25ec Mon Sep 17 00:00:00 2001 From: yuhui <173983476@qq.com> Date: Fri, 9 May 2025 19:59:55 +0800 Subject: [PATCH] =?UTF-8?q?Update=205.=E5=9C=A8=E6=97=A0=E6=A0=87=E8=AE=B0?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E4=B8=8A=E8=BF=9B=E8=A1=8C=E9=A2=84?= =?UTF-8?q?=E8=AE=AD=E7=BB=83.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cn-Book/5.在无标记数据集上进行预训练.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cn-Book/5.在无标记数据集上进行预训练.md b/cn-Book/5.在无标记数据集上进行预训练.md index 7097f3e..079d8f0 100644 --- a/cn-Book/5.在无标记数据集上进行预训练.md +++ b/cn-Book/5.在无标记数据集上进行预训练.md @@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size, idx_next = torch.argmax(logits, dim=-1, keepdim=True) if idx_next == eos_id: #E break + idx_next = idx_next.unsqueeze(1) idx = torch.cat((idx, idx_next), dim=1) - return idx +return idx #A For循环与之前相同:获取logits,仅关注最后的时间步