Merge pull request #12 from dblate/patch-10

Update 5.在无标记数据集上进行预训练.md
This commit is contained in:
long_long_ago 2025-05-10 22:39:07 +08:00 committed by GitHub
commit 1fad72948d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 1 deletion

View File

@@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size,
idx_next = torch.argmax(logits, dim=-1, keepdim=True) idx_next = torch.argmax(logits, dim=-1, keepdim=True)
if idx_next == eos_id: #E if idx_next == eos_id: #E
break break
idx_next = idx_next.unsqueeze(1)
idx = torch.cat((idx, idx_next), dim=1) idx = torch.cat((idx, idx_next), dim=1)
return idx return idx
#A For循环与之前相同:获取logits,仅关注最后的时间步 #A For循环与之前相同:获取logits,仅关注最后的时间步