fixed mirostat issue
This commit is contained in:
parent
99e980d323
commit
47b3077415
|
@ -595,7 +595,7 @@ class Base(nn.Module):
|
||||||
if beam_width > 1:
|
if beam_width > 1:
|
||||||
candidates = top_k_logits_list( logits, beam_width )
|
candidates = top_k_logits_list( logits, beam_width )
|
||||||
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
|
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
|
||||||
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
|
scores = [ logits[batch].flatten()[token].to(logits[batch].device) for batch, token in candidates ]
|
||||||
return res, scores
|
return res, scores
|
||||||
|
|
||||||
# and sample
|
# and sample
|
||||||
|
|
Loading…
Reference in New Issue
Block a user