fixed mirostat issue
This commit is contained in:
parent
99e980d323
commit
47b3077415
|
@ -595,7 +595,7 @@ class Base(nn.Module):
|
|||
if beam_width > 1:
|
||||
candidates = top_k_logits_list( logits, beam_width )
|
||||
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
|
||||
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
|
||||
scores = [ logits[batch].flatten()[token].to(logits[batch].device) for batch, token in candidates ]
|
||||
return res, scores
|
||||
|
||||
# and sample
|
||||
|
|
Loading…
Reference in New Issue
Block a user