From f4fcc35aa8d30ddf92ad1a86550e98157cdd182d Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 18:26:15 -0500
Subject: [PATCH] fixed it breaking on subsequent utterances through the web UI from latents being on the CPU

---
 tortoise_tts/__main__.py  |  4 ++--
 tortoise_tts/inference.py | 30 +++++++++++++++++++++---------
 tortoise_tts/webui.py     |  4 ++--
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py
index d64b16e..56441fa 100755
--- a/tortoise_tts/__main__.py
+++ b/tortoise_tts/__main__.py
@@ -13,8 +13,8 @@ def main():
 	parser.add_argument("--out-path", type=Path, default=None)
 	parser.add_argument("--max-ar-steps", type=int, default=500)
 	parser.add_argument("--max-diffusion-steps", type=int, default=80)
-	parser.add_argument("--ar-temp", type=float, default=1.0)
-	parser.add_argument("--diffusion-temp", type=float, default=0.01)
+	parser.add_argument("--ar-temp", type=float, default=0.8)
+	parser.add_argument("--diffusion-temp", type=float, default=1.0)
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=16)
 	parser.add_argument("--repetition-penalty", type=float, default=1.0)
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index 6ae8852..d117542 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -103,7 +103,7 @@ class TTS():
 		max_diffusion_steps=80,
 		#max_ar_context=-1,
 		#input_prompt_length=0.0,
-		ar_temp=1.0,
+		ar_temp=0.8,
 		diffusion_temp=1.0,
 		#min_ar_temp=0.95,
 		#min_diffusion_temp=0.5,
@@ -131,8 +131,6 @@ class TTS():
 		clvp = None
 		vocoder = None
 		diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=cond_free)
-
-		autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]
 
 		for name, engine in self.engines.items():
 			if "autoregressive" in name:
@@ -152,6 +150,10 @@ class TTS():
 			clvp = load_model("clvp", device=cfg.device)
 		if vocoder is None:
 			vocoder = load_model("vocoder", device=cfg.device)
+
+		autoregressive = autoregressive.to(cfg.device)
+		diffusion = diffusion.to(cfg.device)
+		autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]
 
 		# shove everything to cpu
 		if cfg.inference.auto_unload:
@@ -164,6 +166,8 @@ class TTS():
 
 		# other vars
 		calm_token = 832
+		candidates = 1
+
 		for line in lines:
 			if out_path is None:
 				output_dir = Path("./data/results/")
@@ -185,7 +189,7 @@ class TTS():
 					do_sample=True,
 					top_p=top_p,
 					temperature=ar_temp,
-					num_return_sequences=1,
+					num_return_sequences=candidates,
 					length_penalty=length_penalty,
 					repetition_penalty=repetition_penalty,
 					max_generate_length=max_ar_steps,
@@ -213,12 +217,13 @@ class TTS():
 
 				wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
 
+				# to-do: actually move this after the CLVP to get the best latents instead
 				latents = autoregressive.forward(
-					autoregressive_latents,
-					text_tokens,
-					text_lengths,
+					autoregressive_latents if candidates <= 1 else autoregressive_latents.repeat(candidates, 1),
+					text_tokens if candidates <= 1 else text_tokens.repeat(candidates, 1),
+					text_lengths if candidates <= 1 else text_lengths.repeat(candidates, 1),
 					codes,
-					wav_lengths,
+					wav_lengths if candidates <= 1 else wav_lengths.repeat(candidates, 1),
 					return_latent=True,
 					clip_inputs=False
 				)
@@ -233,6 +238,13 @@ class TTS():
 						latents = latents[:, :k]
 						break
 
+			# clvp pass
+			if candidates > 1:
+				with ml.auto_unload(clvp, enabled=cfg.inference.auto_unload):
+					scores = clvp(text_tokens.repeat(codes.shape[0], 1), codes, return_loss=False)
+					indices = torch.topk(scores, k=candidates).indices
+					codes = codes[indices]
+
 			# diffusion pass
 			with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
 				output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
@@ -258,6 +270,6 @@ class TTS():
 			if out_path is not None:
 				torchaudio.save( out_path, wav.cpu(), sr )
 			wavs.append(wav)
-	
+
 		return (torch.concat(wavs, dim=-1), sr)
 
diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py
index ca1486c..4969b40 100644
--- a/tortoise_tts/webui.py
+++ b/tortoise_tts/webui.py
@@ -219,8 +219,8 @@ with ui:
 				layout["inference"]["inputs"]["diffusion-sampler"] = gr.Radio( ["P", "DDIM"], value="DDIM", label="Diffusion Samplers", type="value", info="Sampler to use during the diffusion pass." )
 				layout["inference"]["inputs"]["cond-free"] = gr.Checkbox(label="Cond. Free", value=True, info="Condition Free diffusion")
 			with gr.Row():
-				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
-				layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
+				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.8, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
+				layout["inference"]["inputs"]["diffusion-temp"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (Diffusion)", info="Modifies the initial noise during the diffusion pass.")
 			"""
 			with gr.Row():
 				layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
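Note on the bug this patch addresses: the failure is the standard PyTorch device mismatch. The conditioning latents were encoded while everything sat on the CPU (after auto_unload shoved the models there following a previous utterance), and were then fed to models that had been moved back to cfg.device. Below is a minimal, self-contained sketch of that failure mode using only stock PyTorch; the Linear module and tensor names are illustrative stand-ins, not the project's code.

    import torch

    model = torch.nn.Linear(16, 16)   # stand-in for the autoregressive/diffusion model
    latents = torch.randn(1, 16)      # "latents" produced while the model lived on the CPU

    if torch.cuda.is_available():
        model = model.to("cuda")      # model moved back onto the device for the next utterance
        try:
            model(latents)            # CPU latents vs. CUDA weights -> RuntimeError
        except RuntimeError as err:
            print("device mismatch:", err)
        # encoding (or moving) the latents after the model is on-device works
        print(model(latents.to("cuda")).shape)

The patch applies the same ordering fix: autoregressive and diffusion are moved to cfg.device before self.encode_audio( references ) runs, so the reference latents are created on the device the models will actually run on, even after a prior utterance unloaded everything to the CPU.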