From fc8dfd86178ae1aa5436b84cf49e51bc000be0d2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 18 Oct 2024 16:55:00 -0500
Subject: [PATCH] made greedy AR sampling viable (and preferable), with caveats
 (per comment in vall_e.models.ar_nar)

---
 vall_e/inference.py         |  2 +-
 vall_e/models/ar_nar.py     | 20 +++++++++++++++++++-
 vall_e/models/arch/llama.py |  4 ++--
 vall_e/models/base.py       | 14 ++++++--------
 vall_e/utils/utils.py       |  2 ++
 vall_e/webui.py             |  2 +-
 6 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/vall_e/inference.py b/vall_e/inference.py
index 44d2e1e..a530813 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -234,7 +234,7 @@ class TTS():
 		if "nar" in engine.hyper_config.capabilities:
 			model_nar = engine.module
 
-		set_seed(seed)
+		seed = set_seed(seed)
 
 		if task == "stt":
 			resp = self.encode_audio( references )
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 84c3ee4..8454f53 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -172,7 +172,6 @@ class AR_NAR(Base):
 				...
 			else:
 				resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
-
 
 		inputs = self.inputs(
 			text_list=text_list,
@@ -268,6 +267,14 @@ class AR_NAR(Base):
 		scores = [ 1.0 ] * sampling_beam_width
 		entropies = []
 
+		# ick
+		low_temperature = sampling_repetition_penalty == 1.0 and sampling_temperature < 0.5
+		low_temperature_range = cfg.dataset.frames_per_second * 3
+
+		original_sampling_temperature = sampling_temperature
+		original_sampling_repetition_penalty = sampling_repetition_penalty
+		original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
+
 		for i, sequence in enumerate( sequence_list ):
 			# add to text for STT
 			if task_list[i] in text_task:
@@ -284,6 +291,17 @@ class AR_NAR(Base):
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 
+			# greedy sampling in the AR *does* work, but requires some quasi-exotic sampling to work around the initial burst of garbage polluting the rest of the sequence
+			# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
+			# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
+			# to-do: tune these values, maybe have it factor based on confidence scores or something
+			# to-do: see if instead just prefixing with blank audio overcomes the initial noise anyways
+			if low_temperature:
+				enabled = n < low_temperature_range
+				sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
+				sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
+				sampling_temperature = original_sampling_temperature if enabled else 1.0
+
 			inputs = self.inputs(
 				text_list=text_list,
 				proms_list=proms_list,
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 7a587ae..3b9bed7 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -114,7 +114,7 @@ except Exception as e:
 
 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():
-	AVAILABLE_ATTENTIONS.append("flash")
+	AVAILABLE_ATTENTIONS.append("flash_(sdpa)")
 
 if torch.backends.cuda.mem_efficient_sdp_enabled():
 	AVAILABLE_ATTENTIONS.append("mem_efficient")
@@ -141,7 +141,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 			self.mode = torch.nn.attention.SDPBackend.MATH
 		elif self.mode == "mem_efficient":
 			self.mode = torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
-		elif self.mode == "flash":
+		elif self.mode == "flash_(sdpa)":
 			self.mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
 		elif self.mode == "cudnn":
 			self.mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index cb08f2e..03b1f43 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1535,12 +1535,6 @@ class Base(nn.Module):
 		if quant_levels is None and "len" in self.capabilities:
 			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
 
-		# argmax instead
-		if temperature <= 0.0:
-			res = [ logit.argmax(dim=1) for logit in logits ]
-			scores = None
-			return Sampled(res, scores, entropy)
-
 		# perform repetition penalizing
 		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
@@ -1562,7 +1556,7 @@
 		# epsilon float comparison because I don't trust Python
 		if abs(temperature - min_temperature) >= 0.001:
 			logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
-		else:
+		elif temperature > 0.0:
 			logits = [ logit / temperature for logit in logits ]
 
 		# do DRY sampling
@@ -1585,6 +1579,10 @@
 			scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
 		# basic sampling
 		else:
-			res = [ Categorical(logits=logit).sample() for logit in logits ]
+			# argmax instead
+			if temperature <= 0.0:
+				res = [ logit.argmax(dim=1) for logit in logits ]
+			else:
+				res = [ Categorical(logits=logit).sample() for logit in logits ]
 
 		return Sampled(res, scores, entropy)
\ No newline at end of file
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 08d6912..4831a91 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -90,6 +90,8 @@ def set_seed(seed=None):
 	np.random.seed(seed)
 	torch.manual_seed(seed)
 
+	return seed
+
 def _get_named_modules(module, attrname):
 	for name, module in module.named_modules():
 		if hasattr(module, attrname):
diff --git a/vall_e/webui.py b/vall_e/webui.py
index b659198..b49716a 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -348,7 +348,7 @@ with ui:
 				#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 				layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 			with gr.Row():
-				layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+				layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy* sample)")
 				layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 			with gr.Row():
 				#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
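
Editor's note (not part of the commit): the sketch below restates the low-temperature workaround documented in the ar_nar.py comments as a standalone helper, for readers skimming the diff. The function name, the default frame rate, and the dict-returning interface are illustrative assumptions, not the repository's API.

# sketch of the per-step parameter switch for greedy AR sampling (hypothetical helper)
def sampling_params_for_step(
	step: int,
	frames_per_second: int = 75,           # assumed codec frame rate (stands in for cfg.dataset.frames_per_second)
	base_temperature: float = 0.0,         # 0.0 => greedy sampling
	base_repetition_penalty: float = 1.0,
	base_repetition_penalty_decay: float = 0.0,
) -> dict:
	"""Return the sampling parameters to use at a given AR decode step."""
	low_temperature = base_repetition_penalty == 1.0 and base_temperature < 0.5
	if not low_temperature:
		# not in greedy mode: leave the user's settings alone
		return dict(
			temperature=base_temperature,
			repetition_penalty=base_repetition_penalty,
			repetition_penalty_decay=base_repetition_penalty_decay,
		)
	if step < frames_per_second * 3:
		# first ~3 seconds: keep the greedy temperature, but wrangle the initial
		# burst of noise with a heavy, fast-decaying repetition penalty
		return dict(temperature=base_temperature, repetition_penalty=1.35, repetition_penalty_decay=0.5)
	# past the burst: drop the penalty and fall back to a default temperature;
	# the clean greedy prefix keeps the rest of the sequence sounding greedy-sampled
	return dict(temperature=1.0, repetition_penalty=base_repetition_penalty, repetition_penalty_decay=base_repetition_penalty_decay)

# e.g. sampling_params_for_step(0)   -> {'temperature': 0.0, 'repetition_penalty': 1.35, 'repetition_penalty_decay': 0.5}
#      sampling_params_for_step(300) -> {'temperature': 1.0, 'repetition_penalty': 1.0,  'repetition_penalty_decay': 0.0}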