From 8b3d1cf70a5786a6d9e5c3d5c132d9bce9d44007 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 9 Nov 2024 15:07:43 -0600
Subject: [PATCH] Something's Wrong

---
 docs/models.md        |  1 -
 vall_e/models/base.py |  9 ++++++-
 vall_e/models/nar.py  | 14 ++++++-----
 vall_e/samplers.py    | 57 +++++++++++++++++++++++++++++--------------
 4 files changed, 55 insertions(+), 26 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index af49e1b..a1cd6ea 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -51,7 +51,6 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
 * The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
 * The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
   * diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
-  * embedding the current timestep is *required*, despite this technically being encoded in how many masked tokens exist within a sequence.
   * the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.

 The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c13746e..aebd075 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1739,7 +1739,14 @@ class Base(nn.Module):
 		# perform repetition penalizing
 		if prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
-			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+
+			# penalize non-autoregressively
+			if quant_levels is not None:
+				#logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
+				logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+			# penalize autoregressively
+			else:
+				logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 0b5d2b6..5855580 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -252,12 +252,13 @@ class NAR(Base):
 			test_artifact = np.load(path, allow_pickle=True)[()]
 			text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
 			resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
-			proms_list = [ resps for resps in resps_list ]
+			proms_list = [ resps[:75*3, :] for resps in resps_list ]
+			#proms_list = [ resps for resps in resps_list ]
 			len_list = [ resps.shape[0] for resps in resps_list ]
 		"""

 		_super = super()
-		def demask_sampling( seq_len, max_steps=20, temperature=0.3 ):
+		def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
 			starting_temperature = temperature

 			input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
@@ -268,12 +269,13 @@ class NAR(Base):

 			start_noise = 0.0
 			end_noise = 1.0
+			sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......

 			# use hardcoded reference file to test inference capabilities
 			if test_artifact is not None:
 				# because we "set" it later on, it's not implicitly captured
 				nonlocal resps_list

-				start_noise = 0.0
+				start_noise = 0.5
 				noise_p = math.cos( start_noise * math.pi * 0.5 )
 				input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
@@ -348,15 +350,15 @@ class NAR(Base):

 				# sample with gumbelnoise
 				# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
-				#sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
-				sampled_ids = filtered_tokens
+				sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
+				#sampled_ids = filtered_tokens

 				# keep unmasked tokens
 				input_ids = torch.where( is_masked, sampled_ids, input_ids )

 				# update scores (conjugated to put the worst scores at the top)
 				scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
-				# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
+				print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )

 			return input_ids

diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index be3539f..6927f0c 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -8,29 +8,52 @@ from torch import Tensor, einsum, nn
 from einops import rearrange
 from dataclasses import asdict, dataclass, field

+def clamp(n, lo, hi):
+	return max(lo, min(n, hi))
+
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
-def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False, limit=75 ):
+
+# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
 	if factor == 1.0 or previous is None:
 		return logits

-	unique = set()
-	priors = reversed(previous)
-	for distance, token in enumerate(priors):
-		# rep-pen range
-		if limit and distance >= limit:
-			continue
-		# skip if we're only applying the decay once
-		if one_time and token in unique:
-			continue
+	seq_len = logits.shape[0]
+	prev_len = len( previous )
+
+	# apply autoregressively
+	if prev_len < seq_len:
+		unique = set()
+		priors = reversed(previous)
+		for i, token in enumerate(priors):
+			# rep-pen range
+			if limit and i >= limit:
+				continue
+			# skip if we're only applying the decay once
+			if one_time and token in unique:
+				continue
+
+			distance = i + 1
+			logits[-1, token] /= factor * (distance ** decay)
+
+			# add to set if we care about it
+			if one_time:
+				unique.add(token)
+	# apply non-autoregressively
+	else:
+		for i, token in enumerate( previous ):
+			# apply to next token
+			start = i + 1
+			# apply either up to limit tokens, or to the end
+			end = start + limit if limit > 0 else seq_len
+			start = clamp(0, seq_len - 1, start)
+			end = clamp(0, seq_len - 1, end)
+			for j in range( start, end ):
+				distance = j - i
+				logits[j, token] /= factor * (distance ** decay)

-		distance += 1
-		logits[:, token] /= factor * (distance ** decay)
-
-		# add to set if we care about it
-		if one_time:
-			unique.add(token)
 	return logits

@@ -379,8 +402,6 @@ def _sample_entropix(
 	min_p=0.0,
 	cfg=EntropixSamplerConfig(),
 ):
-	def clamp(n, lo, hi):
-		return max(lo, min(n, hi))

 	if top_k == 0:
 		top_k = logits.shape[-1]
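
For context on the masking-instead-of-noise paradigm the patch leans on: a timestep in [0, 1] is mapped to a masking probability with a cosine schedule (`noise_p = cos(t * pi / 2)`), and that probability decides which positions stay masked (set to the stop/mask token) at each demasking step. Below is a minimal sketch of that idea in isolation; `mask_by_timestep` and the concrete `seq_len` / `mask_token` values are illustrative and not part of the patch.

```python
import math
import torch

def mask_by_timestep( input_ids: torch.Tensor, timestep: float, mask_token: int ) -> torch.Tensor:
	# cosine schedule: timestep 0.0 keeps everything masked, timestep 1.0 keeps nothing masked
	noise_p = math.cos( timestep * math.pi * 0.5 )
	# mask each position independently with probability noise_p
	is_masked = torch.rand( input_ids.shape, device=input_ids.device ) < noise_p
	return torch.where( is_masked, torch.full_like( input_ids, mask_token ), input_ids )

# example: a fresh sequence starts fully masked (timestep 0.0), then gets progressively revealed
seq_len, mask_token = 8, 1024
tokens = torch.arange( seq_len )  # stand-in for already-decoded tokens
print( mask_by_timestep( tokens, timestep=0.0, mask_token=mask_token ) )  # everything masked
print( mask_by_timestep( tokens, timestep=0.5, mask_token=mask_token ) )  # roughly 70% masked
```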
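The demasking step in `nar.py` is switched back to `gumbel_sample`, which is ordinary temperature sampling expressed through the Gumbel-max trick: scale the logits by the temperature, add Gumbel noise, take the argmax. A sketch of that helper in the common lucidrains style; it is not guaranteed to match the repo's exact implementation.

```python
import torch

def gumbel_noise( t: torch.Tensor ) -> torch.Tensor:
	# Gumbel(0, 1) noise via inverse transform sampling; epsilons guard against log(0)
	noise = torch.zeros_like( t, dtype=torch.float ).uniform_( 0, 1 )
	return -torch.log( -torch.log( noise + 1e-20 ) + 1e-20 )

def gumbel_sample( logits: torch.Tensor, temperature: float = 1.0, dim: int = -1 ) -> torch.Tensor:
	# temperature -> 0 degenerates to a plain argmax (greedy decoding)
	if temperature <= 0.0:
		return logits.argmax( dim=dim )
	return ( ( logits / max( temperature, 1e-10 ) ) + gumbel_noise( logits ) ).argmax( dim=dim )

# example: sample token ids for every position of a [seq_len, vocab] logit matrix at once
sampled_ids = gumbel_sample( torch.randn( 8, 1024 ), temperature=1.0 )
```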
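The reworked `reptition_penalize` splits into two behaviours: applied autoregressively, only the last position's logits are penalized, with the penalty decaying the further back a token occurred; applied non-autoregressively, each already-decided token penalizes the positions that come after it. The sketch below restates that split as two standalone helpers, assuming `logits` is shaped `[seq_len, vocab]`; the `clamp` calls pass the value first to match the `clamp(n, lo, hi)` helper the patch adds, which is an assumption about the intended argument order rather than a copy of the patch's own calls, and the default `factor` is illustrative.

```python
import torch

def clamp( n, lo, hi ):
	return max( lo, min( n, hi ) )

def penalize_ar( logits: torch.Tensor, previous: list, factor: float = 1.25, decay: float = 0.0, limit: int = 75 ) -> torch.Tensor:
	# autoregressive: only the next-token logits (last row) are penalized,
	# with the penalty decaying the further back a token occurred
	for distance, token in enumerate( reversed( previous ), start=1 ):
		if limit and distance > limit:
			break
		logits[-1, token] /= factor * ( distance ** decay )
	return logits

def penalize_nar( logits: torch.Tensor, previous: list, factor: float = 1.25, decay: float = 0.0, limit: int = 75 ) -> torch.Tensor:
	# non-autoregressive: every already-decided token penalizes the positions after it
	seq_len = logits.shape[0]
	for i, token in enumerate( previous ):
		start = clamp( i + 1, 0, seq_len )  # first future position
		end = clamp( start + limit if limit > 0 else seq_len, 0, seq_len )
		for j in range( start, end ):
			logits[j, token] /= factor * ( ( j - i ) ** decay )
	return logits
```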