mrq 2024-10-12 10:41:35 -05:00
parent 3d6ef9666b
commit 666e8038fb
3 changed files with 25 additions and 19 deletions

View File

@@ -245,7 +245,6 @@ class AR_NAR(Base):
			)
			resps_list = sampled[0]
			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
		return prev_list

@@ -377,6 +376,7 @@ class AR_NAR(Base):
		sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
		# remove <bos>
		sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
		return sequence_list
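For context, the hunks above trim each finished sequence: self._prune cuts it at the task's stop token, and the start_slice offsets then drop the leading <bos> tokens. A minimal sketch of that trimming step, assuming _prune simply truncates at the first occurrence of the stop token and start_slice holds one <bos> token per sequence (both are assumptions; the repo's actual helpers may differ), not part of this commit:

import torch

def prune_at_stop(seq: torch.Tensor, stop_token: int) -> torch.Tensor:
	# keep everything before the first stop token, or the whole sequence if none was emitted
	hits = (seq == stop_token).nonzero(as_tuple=True)[0]
	return seq[: int(hits[0])] if hits.numel() else seq

stop_token = 1024                     # assumed audio stop token id
start_slice = [1, 1]                  # assumed: one <bos> token per sequence
sequence_list = [ torch.tensor([2, 5, 7, stop_token, 0]), torch.tensor([2, 3, 4]) ]
sequence_list = [ prune_at_stop(seq, stop_token) for seq in sequence_list ]
sequence_list = [ seq[start_slice[i]:] for i, seq in enumerate(sequence_list) ]
# -> [tensor([5, 7]), tensor([3, 4])]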

View File

@@ -1507,23 +1507,6 @@ class Base(nn.Module):
		elif self.causal:
			logits = [ logit[-self.causal_size:] for logit in logits ]

		# entropix sampling
		if attentions is not None:
			# move to CPU for speedups
			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

			res = [ sample_entropix(
				logit,
				attentions[-1], #torch.stack(attentions, dim=1),
				temperature,
				top_k,
				top_p,
				min_p,
			) for logit in logits ]

			if res:
				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])

		# (NAR) disable stop token
		if quant_levels is not None and "ar" in self.capabilities:
			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
@@ -1546,6 +1529,23 @@ class Base(nn.Module):
		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

		# (AR) entropix sampling
		if attentions is not None and quant_levels is None:
			# move to CPU for speedups
			logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]

			res = [ sample_entropix(
				logit,
				attentions[-1], #torch.stack(attentions, dim=1),
				temperature,
				top_k,
				top_p,
				min_p,
			) for logit in logits ]

			if res:
				return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])

		# perform min_p filtering of our logits
		if min_p > 0.0:
			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
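The net effect of the two hunks above: the entropix branch now runs only for the AR pass (attentions is not None and quant_levels is None), and only after the NAR stop-token ban and the AR length penalty, instead of firing before them on every pass. When the branch does not return early, sampling falls through to min_p filtering. As a reference for what that filter does, a small self-contained sketch of min_p filtering (the repo's min_p_filtering helper may differ in details): any token whose probability is below min_p times the top token's probability is masked out.

import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
	# scale the cutoff by the most likely token's probability
	probs = torch.softmax(logits, dim=-1)
	top_prob = probs.max(dim=-1, keepdim=True).values
	# mask out anything below min_p * top_prob so it can never be sampled
	return logits.masked_fill(probs < min_p * top_prob, float("-inf"))

# with min_p=0.1, only tokens within 10x of the best token's probability survive
print( min_p_filter(torch.tensor([3.0, 2.5, 0.0, -4.0]), min_p=0.1) )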

View File

@@ -356,7 +356,7 @@ def _sample_entropix(
	probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
	next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
	return torch.take_along_dim(probs_idx, next_token, dim=-1)
	return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]

def sample_entropix(
	logits,
@@ -430,6 +430,12 @@ def sample_entropix(
			agreement * cfg.ada_score_agree +
			interaction_strength * cfg.ada_score_int
		)

		"""
		if 1024 in sample:
			return 1000
		"""
		return log_prob + confidence_score

	sample_scores = [ score_sample(sample) for sample in samples ]
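Two details worth noting in the sampler change above. The new return ...[0] appears to unwrap the leading (batch) dimension of the selected token, and the quoted-out if 1024 in sample block looks like a disabled hack that would have forced any candidate containing the (presumed) stop token to win the scoring. The probs_sort / Exponential.sample(...) line itself is the usual exponential-race (Gumbel-max style) trick: with E_i i.i.d. Exponential(1), argmax_i p_i / E_i selects index i with probability proportional to p_i. A small sketch that checks this empirically, using torch.distributions directly (the Exponential object in the repo is assumed to be an equivalent prebuilt distribution):

import torch
from torch.distributions import Exponential

torch.manual_seed(0)
probs = torch.tensor([0.6, 0.3, 0.1])
exp = Exponential(torch.ones_like(probs))   # rate 1 per category

# draw many tokens via argmax(p / E) and compare frequencies to the target probabilities
draws = torch.stack([ torch.argmax(probs / exp.sample()) for _ in range(10000) ])
freqs = torch.bincount(draws, minlength=len(probs)).float() / len(draws)
print(freqs)   # ~ tensor([0.6, 0.3, 0.1])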