From 8068f24e35f936b0a3db0486b05ac022f790ff52 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 20 Mar 2025 15:56:15 -0500
Subject: [PATCH] cleaned up parallel nar, i think it's slightly faster but
 even the smallest model is still slower than ar+nar-len-llama-8...

---
 vall_e/models/ar_nar_v2.py | 64 ++++++++++----------------------------
 vall_e/models/base_v2.py   | 30 +++++++++++-------
 2 files changed, 34 insertions(+), 60 deletions(-)

diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index b9ac51d..fc494b8 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -263,7 +263,7 @@ class AR_NAR_V2(Base_V2):
 		# fill with masked tokens (even though they get masked anyways)
 		resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
 		# fill scores
-		scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
+		scores = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.float32, device=device) for seq_len in len_list ]
 
 		quant_levels = [ level for _ in range(batch_size) ]
 		null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ]
@@ -280,10 +280,11 @@ class AR_NAR_V2(Base_V2):
 			# proportion of tokens to remask
 			remask_p = 1.0 / (max_steps * 2) if remasking else 0
 			# pick the worst scoring tokens to mask off
-			masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
+			masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=0 ).indices for score, seq_len in zip(scores, len_list) ]
+
 			# normal masking
 			# mask off inputs
-			resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
+			resps_list = [ torch.stack([resp[:, l].scatter(0, indices.t()[l], self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
 			# boolean mask
 			is_masked = [ resps == self.mask_token for resps in resps_list ]
 			# timestep inputs
@@ -327,53 +328,20 @@ class AR_NAR_V2(Base_V2):
 				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )
 
-			l_scores = []
-			l_resps_list = []
-			# cringe hack because we're able to sample multiple levels at once
-			for l in range(self.n_resp_levels):
-				# sample with sampler settings
-				filtered_sampled = super().sample(
-					logits=[ logit[l] for logit in logits ],
-					prev_list=[ resp[..., l] for resp in prev_list ],
-					quant_levels=quant_levels,
+			# sample with sampler settings
+			sampled = super().sample(
+				logits=logits,
+				prev_list=resps_list,
+				quant_levels=quant_levels,
 
-					temperature=sampling_temperature,
-					**sampling_kwargs,
-				)
+				temperature=sampling_temperature,
+				**sampling_kwargs,
+			)
 
-				# retrieves unfiltered logits
-				unfiltered_sampled = super().sample(
-					logits=[ logit[l] for logit in logits ],
-					prev_list=[ resp[..., l] for resp in prev_list ],
-					quant_levels=quant_levels,
-
-					temperature=0.0,
-					**sampling_kwargs,
-				)
-
-				# get sampled tokens
-				sampled_ids = filtered_sampled.ids
-				# keep unmasked tokens
-				l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ])
-				# get probability scores
-				l_scores.append([
-					# conjugate to have worse scoring tokens picked for topk
-					1.0 -
-					# only keep scores of tokens we are predicting (and ignore the tokens previously finalized)
-					torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) )
-					# use unmodified logit scores for this, as it offers better stability
-					for scores, masked in zip( unfiltered_sampled.scores, is_masked )
-				])
-
-			resps_list = []
-			scores = []
-
-			for batch_index in range(batch_size):
-				score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels
-				resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1)
-
-				scores.append( score )
-				resps_list.append( resp )
+			# update resps, filling in the masked tokens with the new tokens
+			resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
+			# update scores, filling in the scores of the newly-sampled tokens (inverted, so the worst-scoring tokens get picked for remasking)
+			scores = [ 1.0 - torch.where( masked, scores.t(), 1 ) for masked, scores in zip( is_masked, sampled.scores ) ]
 
 		return resps_list
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 3ddd4b6..fea0c0b 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1236,8 +1236,9 @@ class Base_V2(nn.Module):
 	def sample(
 		self,
-		logits: list[Tensor], # logit scores
-		prev_list: list[Tensor] | None = None, # logit scores
+		logits: Tensor, # logit scores
+		prev_list: Tensor | None = None,
+		len_list: Tensor | None = None,
 
 		**sampling_kwargs,
 	):
 		# yikes
@@ -1265,6 +1266,7 @@ class Base_V2(nn.Module):
 		attentions = sampling_kwargs.get("attentions", None)
 
 		batch_size = len( logits )
+		device = logits[0].device
 
 		if min_temperature < 0:
 			min_temperature = temperature
@@ -1273,14 +1275,16 @@ class Base_V2(nn.Module):
 		entropy = None
 
 		if prev_list is not None:
-			seq_lens = map(len, prev_list)
-			logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
-		# (AR chunkwise) return the last chunkwise piece
+			seq_lens = [ prev.shape[0] for prev in prev_list ]
+		elif len_list is not None:
+			seq_lens = len_list
 		elif self.causal:
-			seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
-			logits = [ logit[-self.causal_size:] for logit in logits ]
+			seq_lens = [ self.causal_size for _ in range( batch_size) ]
+
+		logits = [ logit[..., -l:, :] for l, logit in zip(seq_lens, logits) ]
 
 		# perform min_p filtering of our logits
+		"""
 		if min_p > 0.0:
 			logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]
 
@@ -1291,17 +1295,19 @@ class Base_V2(nn.Module):
 		# do top-no logit processing
 		if top_no > 0.0:
 			logits = [ top_no_logits_processing(logit) for logit in logits ]
+		"""
+
+		probabilities = [ F.softmax(logit, dim=-1) for logit in logits ]
+		scores = [ torch.max(prob, -1)[0] for prob in probabilities ]
 
-		# argmax instead
 		if temperature <= 0.0:
-			res = [ logit.argmax(dim=-1) for logit in logits ]
+			res = [ prob.argmax(dim=-1) for prob in probabilities ]
 		else:
 			res = [ Categorical(logits=logit / temperature).sample() for logit in logits ]
 
-		# calculate token probabilities
 		scores = [
-			[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
-			for logit, tokens in zip(logits, res)
+			torch.tensor([ [ prob[b, i, token].item() for i, token in enumerate(tokens[b]) ] for b in range(prob.size(0)) ], device=device)
+			for prob, tokens in zip(probabilities, res)
 		]
 
 		return Sampled(res, logits, scores, entropy)
\ No newline at end of file
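
What the ar_nar_v2.py hunks do: the old code looped over quantizer levels, sampled each level separately, and ran a second temperature-0 sampling pass just to read off per-token scores; the new code keeps `scores` as a 2D (seq_len, n_resp_levels) tensor and samples every level in one `super().sample()` call, so scores now come from the same filtered distribution that produced the tokens. Below is a minimal sketch of the resulting update step, assuming the shapes the patch implies (responses as (seq_len, n_resp_levels); sampled ids and scores returned per level as (n_resp_levels, seq_len), hence the `.t()` transposes). Names and sizes are illustrative dummies, not the model's real ones:

    import torch

    seq_len, n_resp_levels, mask_token = 8, 4, 1024

    # everything starts masked, as in the patch's resps_list initialization
    resps = torch.full((seq_len, n_resp_levels), mask_token, dtype=torch.int16)
    is_masked = resps == mask_token                      # (seq_len, n_resp_levels) bool
    sampled_ids = torch.randint(0, mask_token, (n_resp_levels, seq_len))
    sampled_scores = torch.rand(n_resp_levels, seq_len)  # probability of each sampled token

    # fill masked positions with the freshly sampled tokens, keep finalized ones
    resps = torch.where(is_masked, sampled_ids.t(), resps).to(torch.int16)
    # invert probabilities so low-confidence tokens score highest; positions that
    # were not re-predicted become 1.0 - 1 = 0 and are never picked for remasking
    ones = torch.ones((seq_len, n_resp_levels))
    scores = 1.0 - torch.where(is_masked, sampled_scores.t(), ones)

    # topk over dim=0 now ranks positions within each level independently
    # (the dim=-1 -> dim=0 change), giving (k, n_resp_levels) indices; the
    # patch then reads the positions to remask at level l as indices.t()[l]
    k = 2
    masked_indices = scores.topk(k, dim=0).indices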
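On the base_v2.py side, `sample()` now softmaxes once, and the patch collects the probability of each chosen token with a nested Python loop (`prob[b, i, token].item()` per element) before wrapping the result in a tensor. If that loop ever shows up in profiles, a vectorized equivalent exists via `torch.gather`; this is a sketch of that alternative under assumed (n_resp_levels, seq_len, vocab) logit shapes, not what the patch itself does:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 8, 1025)            # (n_resp_levels, seq_len, vocab)
    probabilities = F.softmax(logits, dim=-1)
    tokens = probabilities.argmax(dim=-1)       # greedy path (temperature <= 0.0)

    # probability of each chosen token: the same values as the element-wise
    # loop, but computed in one gather instead of per-item .item() round-trips
    scores = probabilities.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # (n_resp_levels, seq_len)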