fixed mirostat issue

This commit is contained in:
mrq 2023-10-10 18:09:49 -05:00
parent 99e980d323
commit 47b3077415

View File

@ -595,7 +595,7 @@ class Base(nn.Module):
if beam_width > 1:
candidates = top_k_logits_list( logits, beam_width )
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
scores = [ logits[batch].flatten()[token].to(logits[batch].device) for batch, token in candidates ]
return res, scores
# and sample