diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index e2a0f8e..242c14d 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -316,9 +316,9 @@ class AR_NAR(Base):
 			prev_list = resps_list
 
 			# sample with gumbelnoise
-			# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
-			sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
-			#sampled_ids = filtered_sampled[0]
+			# This actually lobotomizes things
+			#sampled_ids = [ gumbel_sample( logits, temperature=temperature, dim=-1 ) for logits in filtered_sampled.logits[0] ]
+			sampled_ids = filtered_sampled[0]
 
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
@@ -447,7 +447,8 @@ class AR_NAR(Base):
 				logits=logits,
 				prev_list=prev_list,
 				quant_levels=quant_levels,
-				**sampling_kwargs,
+				#temperature=0.0,
+				**(sampling_kwargs | {"temperature": 0.0}),
 			)
 
 			resps_list = sampled[0]
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0867e3d..809773d 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1710,6 +1710,10 @@ class Base(nn.Module):
 		if min_temperature < 0:
 			min_temperature = temperature
 
+		# pick last RVQ level
+		if prev_list is not None:
+			prev_list = [ prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list ]
+
 		scores = None
 		entropy = None
 		#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
@@ -1763,15 +1767,12 @@ class Base(nn.Module):
 
 		# perform repetition penalizing
 		if prev_list is not None and repetition_penalty != 1.0:
-			# to-do: figure out a faster way to handle tolist()
-			# penalize non-autoregressively
 			if quant_levels is not None:
-				#logits = [ reptition_penalize(logit, previous=logit.argmax(dim=1).tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit in logits ]
-				logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 			# penalize autoregressively
 			else:
-				logits = [ reptition_penalize(logit, previous=prevs.tolist() if prevs.dim() == 1 else prevs[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
+				logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
@@ -1794,7 +1795,7 @@ class Base(nn.Module):
 
 		# do DRY sampling
 		if dry_multiplier > 0.0 and prev_list is not None:
-			logits = [ dry_sampling(logit, previous=prevs[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
+			logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
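Review note: the `# pick last RVQ level` hunk above normalizes `prev_list` once, up front, which is why the `reptition_penalize`/`dry_sampling` call sites below it can drop their per-call `.dim()` checks and `.tolist()` conversions. A minimal sketch of that normalization; the shapes and vocabulary size here are made up for illustration:

```python
import torch

# Hypothetical shapes: each prev_list entry is either a 1D tensor of token
# ids [seq_len], or a 2D tensor [seq_len, n_rvq_levels]; the sampler filters
# downstream want exactly one token per sequence position.
prev_list = [
	torch.randint(0, 1024, (8,)),    # already 1D: kept as-is
	torch.randint(0, 1024, (8, 4)),  # 2D: reduced to its last RVQ level
]
prev_list = [ prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list ]
assert all( prevs.dim() == 1 for prevs in prev_list )
```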
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 1d06d1a..ae7f1e8 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -13,15 +13,17 @@ from .utils import clamp
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
 # `decay` is a factor that will exponentially apply to how far away it is
-
-# this is split between applying autoregressively (applying to the last token, starting from the end), and applying non-autoregressively (starting from the beginning, and applying to tokens in the future)
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
+def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=True, limit=0 ):
 	if factor == 1.0 or previous is None:
 		return logits
+
 	unique = set()
-	priors = reversed(previous)
-	for distance, token in enumerate(priors):
+	is_nar = previous.shape[0] == logits.shape[0]
+
+	for i, token in enumerate( previous ):
+		distance = previous.shape[0] - i
+
 		# rep-pen range
 		if limit and distance >= limit:
 			continue
@@ -29,8 +31,14 @@ def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=F
 		if one_time and token in unique:
 			continue
 
-		distance += 1
-		logits[:, token] /= factor * (distance ** decay)
+		start = None
+		end = None
+
+		# apply only to future tokens
+		if is_nar and i < logits.shape[0]:
+			start = i + 1
+
+		logits[start:end, token] /= factor * (distance ** decay)
 
 		# add to set if we care about it
 		if one_time:
@@ -38,50 +46,6 @@
 
 	return logits
 
-"""
-# I do not know why this is a regression...
-def reptition_penalize( logits, previous=None, factor=1.0, decay=0.0, one_time=False, limit=75 ):
-	if factor == 1.0 or previous is None:
-		return logits
-
-	seq_len = logits.shape[0]
-	prev_len = len( previous )
-
-	# apply autoregressively
-	if prev_len < seq_len:
-		unique = set()
-		priors = reversed(previous)
-		for i, token in enumerate(priors):
-			# rep-pen range
-			if limit and i >= limit:
-				continue
-			# skip if we're only applying the decay once
-			if one_time and token in unique:
-				continue
-
-			distance = i + 1
-			logits[-1, token] /= factor * (distance ** decay)
-
-			# add to set if we care about it
-			if one_time:
-				unique.add(token)
-	# apply non-autoregressively
-	else:
-		for i, token in enumerate( previous ):
-			# apply to next token
-			start = i + 1
-			# apply either up to limit tokens, or to the end
-			end = start + limit if limit > 0 else seq_len
-			start = clamp(start, 0, seq_len - 1)
-			end = clamp(end, 0, seq_len - 1)
-			for j in range( start, end ):
-				distance = j - i
-				logits[j, token] /= factor * (distance ** decay)
-
-
-	return logits
-"""
-
 # Simple "filter" that modifies the logit for the stop token, based on the sequence length
 # `length` is the length of the sequence currently
 # `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
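Review note: the reworked `reptition_penalize` now keeps `previous` as a tensor and infers the NAR case from `previous.shape[0] == logits.shape[0]`; in that case a token at position `i` only penalizes logits at positions after `i`, since earlier positions are already decided. A toy trace of that future-only path, with made-up values and the `one_time`/`limit` handling omitted for brevity:

```python
import torch

# In the NAR case, `previous` holds one token per output position, so
# previous.shape[0] == logits.shape[0]; position i's token should only
# dampen re-picking that token at positions i+1 onward.
seq_len, vocab = 4, 8
logits = torch.ones(seq_len, vocab)
previous = torch.tensor([2, 2, 5, 1])
factor, decay = 1.5, 0.0

is_nar = previous.shape[0] == logits.shape[0]
for i, token in enumerate( previous.tolist() ):
	distance = previous.shape[0] - i
	start = i + 1 if is_nar else None  # None slices the whole column (AR case)
	logits[start:, token] /= factor * (distance ** decay)

assert logits[0, 2] == 1.0  # a position is not penalized by its own token
assert logits[1, 2] < 1.0   # but later occurrences of token 2 are dampened
```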
diff --git a/vall_e/webui.py b/vall_e/webui.py
index e3f4d0c..106adcc 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -259,6 +259,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	gr.Info("Inferencing...")
 
 	sampling_kwargs = dict(
+		max_steps=args.max_steps,
+		max_levels=args.max_levels,
 		max_duration=args.max_duration,
 		ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
 		min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
@@ -467,7 +469,7 @@
 					layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
 			with gr.Tab("Experimental Settings", visible=cfg.experimental):
 				with gr.Row():
-					layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=50, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
+					layout["inference_tts"]["inputs"]["max-steps"] = gr.Slider(value=25, minimum=1, maximum=500, step=1, label="Max NAR Steps", info="Limits how many steps to perform in the NAR (demask) pass.")
 					layout["inference_tts"]["inputs"]["max-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 					layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 		with gr.Row():
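Review note: the `**(sampling_kwargs | {"temperature": 0.0})` spread in the `ar_nar.py` hunk relies on PEP 584 dict union (Python 3.9+), where the right-hand operand wins on key conflicts. A quick sketch of the semantics, with illustrative values:

```python
# The merged dict keeps every user-supplied knob but forces greedy
# sampling for this call, since the right-hand side overrides on conflict.
sampling_kwargs = {"temperature": 0.7, "top_k": 64}  # example user settings
merged = sampling_kwargs | {"temperature": 0.0}
assert merged == {"temperature": 0.0, "top_k": 64}
```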