ugh
This commit is contained in:
parent
3d6ef9666b
commit
666e8038fb
|
@ -245,7 +245,6 @@ class AR_NAR(Base):
|
|||
)
|
||||
|
||||
resps_list = sampled[0]
|
||||
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
@ -377,6 +376,7 @@ class AR_NAR(Base):
|
|||
sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)]
|
||||
# remove <bos>
|
||||
sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ]
|
||||
|
||||
return sequence_list
|
||||
|
||||
|
||||
|
|
|
@ -1507,23 +1507,6 @@ class Base(nn.Module):
|
|||
elif self.causal:
|
||||
logits = [ logit[-self.causal_size:] for logit in logits ]
|
||||
|
||||
# entropix sampling
|
||||
if attentions is not None:
|
||||
# move to CPU for speedups
|
||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
|
||||
res = [ sample_entropix(
|
||||
logit,
|
||||
attentions[-1], #torch.stack(attentions, dim=1),
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
min_p,
|
||||
) for logit in logits ]
|
||||
|
||||
if res:
|
||||
return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
|
||||
|
||||
# (NAR) disable stop token
|
||||
if quant_levels is not None and "ar" in self.capabilities:
|
||||
logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
|
||||
|
@ -1546,6 +1529,23 @@ class Base(nn.Module):
|
|||
if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
|
||||
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
|
||||
|
||||
# (AR) entropix sampling
|
||||
if attentions is not None and quant_levels is None:
|
||||
# move to CPU for speedups
|
||||
logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
|
||||
|
||||
res = [ sample_entropix(
|
||||
logit,
|
||||
attentions[-1], #torch.stack(attentions, dim=1),
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
min_p,
|
||||
) for logit in logits ]
|
||||
|
||||
if res:
|
||||
return Sampled([ r[0] for r in res], scores, [ r[1] for r in res])
|
||||
|
||||
# perform min_p filtering of our logits
|
||||
if min_p > 0.0:
|
||||
logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
|
||||
|
|
|
@ -356,7 +356,7 @@ def _sample_entropix(
|
|||
probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True)
|
||||
|
||||
next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True)
|
||||
return torch.take_along_dim(probs_idx, next_token, dim=-1)
|
||||
return torch.take_along_dim(probs_idx, next_token, dim=-1)[0]
|
||||
|
||||
def sample_entropix(
|
||||
logits,
|
||||
|
@ -430,6 +430,12 @@ def sample_entropix(
|
|||
agreement * cfg.ada_score_agree +
|
||||
interaction_strength * cfg.ada_score_int
|
||||
)
|
||||
|
||||
"""
|
||||
if 1024 in sample:
|
||||
return 1000
|
||||
"""
|
||||
|
||||
return log_prob + confidence_score
|
||||
|
||||
sample_scores = [ score_sample(sample) for sample in samples ]
|
||||
|
|
Loading…
Reference in New Issue
Block a user