commit
1fad72948d
|
|
@ -1069,8 +1069,9 @@ def generate(model, idx, max_new_tokens, context_size,
|
||||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
|
||||||
if idx_next == eos_id: #E
|
if idx_next == eos_id: #E
|
||||||
break
|
break
|
||||||
|
idx_next = idx_next.unsqueeze(1)
|
||||||
idx = torch.cat((idx, idx_next), dim=1)
|
idx = torch.cat((idx, idx_next), dim=1)
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
#A For循环与之前相同:获取logits,仅关注最后的时间步
|
#A For循环与之前相同:获取logits,仅关注最后的时间步
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue