diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index d08abfc..0d37819 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -245,7 +245,6 @@ class AR_NAR(Base): ) resps_list = sampled[0] - prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device=device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] return prev_list @@ -377,6 +376,7 @@ class AR_NAR(Base): sequence_list = [self._prune(r, audio_stop_token if task_list[i] not in text_task else text_stop_token) for i, r in enumerate(sequence_list)] # remove sequence_list = [ sequence_list[i][start_slice[i]:] for i, task in enumerate( task_list ) ] + return sequence_list diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 8ff356a..4ba7b6d 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -1507,23 +1507,6 @@ class Base(nn.Module): elif self.causal: logits = [ logit[-self.causal_size:] for logit in logits ] - # entropix sampling - if attentions is not None: - # move to CPU for speedups - logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] - - res = [ sample_entropix( - logit, - attentions[-1], #torch.stack(attentions, dim=1), - temperature, - top_k, - top_p, - min_p, - ) for logit in logits ] - - if res: - return Sampled([ r[0] for r in res], scores, [ r[1] for r in res]) - # (NAR) disable stop token if quant_levels is not None and "ar" in self.capabilities: logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ] @@ -1546,6 +1529,23 @@ class Base(nn.Module): if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0: logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ] + # (AR) entropix sampling + if attentions is not None and quant_levels is None: + # move to CPU for speedups + logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ] + + res = [ sample_entropix( + logit, + attentions[-1], #torch.stack(attentions, dim=1), + temperature, + top_k, + top_p, + min_p, + ) for logit in logits ] + + if res: + return Sampled([ r[0] for r in res], scores, [ r[1] for r in res]) + # perform min_p filtering of our logits if min_p > 0.0: logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ] diff --git a/vall_e/samplers.py b/vall_e/samplers.py index e5fd390..fb5b38b 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -356,7 +356,7 @@ def _sample_entropix( probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdims=True) next_token = torch.argmax(probs_sort / Exponential.sample(probs_sort.shape), dim=-1, keepdim=True) - return torch.take_along_dim(probs_idx, next_token, dim=-1) + return torch.take_along_dim(probs_idx, next_token, dim=-1)[0] def sample_entropix( logits, @@ -430,6 +430,12 @@ def sample_entropix( agreement * cfg.ada_score_agree + interaction_strength * cfg.ada_score_int ) + + """ + if 1024 in sample: + return 1000 + """ + return log_prob + confidence_score sample_scores = [ score_sample(sample) for sample in samples ]