From 73f271fb8a29f4574be280808163cc23fded3e8e Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 17:01:05 -0500
Subject: [PATCH] added automagic offloading models to GPU then CPU when
 they're done during inference

---
 README.md                     |   2 +-
 tortoise_tts/__main__.py      |   2 +-
 tortoise_tts/config.py        |  13 +---
 tortoise_tts/inference.py     | 135 ++++++++++++++++++----------------
 tortoise_tts/train.py         |   4 +
 tortoise_tts/utils/wrapper.py |   8 ++
 tortoise_tts/webui.py         |   2 +-
 7 files changed, 90 insertions(+), 76 deletions(-)

diff --git a/README.md b/README.md
index 62b52fc..38564e3 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
 - [ ] Reimplement redaction with the Wav2Vec2
 - [X] Implement training support (without DLAS)
 - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
-- [ ] Automagic offloading to CPU for unused models (for training and inferencing)
+- [X] Automagic offloading to CPU for unused models (for training and inferencing)
 - [X] Automagic handling of the original weights into compatible weights
 - [ ] Reimplement added features from my original fork:
   - [ ] "Better" conditioning latents calculating
diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py
index a8d17c9..7632e8a 100755
--- a/tortoise_tts/__main__.py
+++ b/tortoise_tts/__main__.py
@@ -19,7 +19,7 @@ def main():
 	parser.add_argument("--top-k", type=int, default=16)
 	parser.add_argument("--repetition-penalty", type=float, default=1.0)
 	#parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
-	parser.add_argument("--length-penalty", type=float, default=0.0)
+	parser.add_argument("--length-penalty", type=float, default=1.0)
 	parser.add_argument("--beam-width", type=int, default=0)
 
 	parser.add_argument("--diffusion-sampler", type=str, default="ddim")
diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 9ff4254..485f70a 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -21,8 +21,6 @@ from .tokenizer import VoiceBpeTokenizer
 
 # Yuck
 from transformers import PreTrainedTokenizerFast
-from tokenizers import Tokenizer
-
 
 @dataclass()
 class BaseConfig:
@@ -472,17 +470,10 @@ class Inference:
 	weight_dtype: str = "float32"
 	amp: bool = False
 
+	auto_unload: bool = True
+
 	normalize: bool = False # do NOT enable this unless you know exactly what you're doing
 
-	# legacy / backwards compat
-	use_vocos: bool = True
-	use_encodec: bool = True
-	use_dac: bool = True
-
-	# shit that doesn't work
-	recurrent_chunk_size: int = 0
-	recurrent_forward: bool = False
-
 	@cached_property
 	def dtype(self):
 		if self.weight_dtype == "float16":
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index e6ca975..373a491 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -8,6 +8,7 @@ from pathlib import Path
 
 from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
+from .utils import wrapper as ml
 
 from .config import cfg
 from .models import get_models, load_model
@@ -110,7 +111,7 @@ class TTS():
 		top_k=0,
 		repetition_penalty=1.0,
 		#repetition_penalty_decay=0.0,
-		length_penalty=0.0,
+		length_penalty=1.0,
 		beam_width=1,
 		#mirostat_tau=0,
 		#mirostat_eta=0.1,
@@ -151,6 +152,13 @@ class TTS():
 		if vocoder is None:
 			vocoder = load_model("vocoder", device=cfg.device)
 
+		# shove everything to cpu
+		if cfg.inference.auto_unload:
+			autoregressive = autoregressive.to("cpu")
+			diffusion = diffusion.to("cpu")
+			clvp = clvp.to("cpu")
+			vocoder = vocoder.to("cpu")
+
 		wavs = []
 		# other vars
 		calm_token = 832
@@ -168,79 +176,82 @@ class TTS():
 			text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)
 
 			with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-				# autoregressive pass
-				codes = autoregressive.inference_speech(
-					autoregressive_latents,
-					text_tokens,
-					do_sample=True,
-					top_p=top_p,
-					temperature=ar_temp,
-					num_return_sequences=1,
-					length_penalty=length_penalty,
-					repetition_penalty=repetition_penalty,
-					max_generate_length=max_ar_steps,
-				)
-
-				"""
-				padding_needed = max_ar_steps - codes.shape[1]
-				codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
-				"""
+				with ml.auto_unload(autoregressive, enabled=cfg.inference.auto_unload):
+					# autoregressive pass
+					codes = autoregressive.inference_speech(
+						autoregressive_latents,
+						text_tokens,
+						do_sample=True,
+						top_p=top_p,
+						temperature=ar_temp,
+						num_return_sequences=1,
+						length_penalty=length_penalty,
+						repetition_penalty=repetition_penalty,
+						max_generate_length=max_ar_steps,
+					)
 
-				for i, code in enumerate( codes ):
-					stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
-					stm = stop_token_indices.min().item()
+					"""
+					padding_needed = max_ar_steps - codes.shape[1]
+					codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
+					"""
 
-					if len(stop_token_indices) == 0:
-						continue
+					for i, code in enumerate( codes ):
+						stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
+						stm = stop_token_indices.min().item()
 
-					codes[i][stop_token_indices] = 83
-					codes[i][stm:] = 83
+						if len(stop_token_indices) == 0:
+							continue
 
-					if stm - 3 < codes[i].shape[0]:
-						codes[i][-3] = 45
-						codes[i][-2] = 45
-						codes[i][-1] = 248
+						codes[i][stop_token_indices] = 83
+						codes[i][stm:] = 83
 
-				wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
+						if stm - 3 < codes[i].shape[0]:
+							codes[i][-3] = 45
+							codes[i][-2] = 45
+							codes[i][-1] = 248
 
-				latents = autoregressive.forward(
-					autoregressive_latents,
-					text_tokens,
-					text_lengths,
-					codes,
-					wav_lengths,
-					return_latent=True,
-					clip_inputs=False
-				)
+					wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
 
-				calm_tokens = 0
-				for k in range( codes.shape[-1] ):
-					if codes[0, k] == calm_token:
-						calm_tokens += 1
-					else:
-						calm_tokens = 0
-					if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
-						latents = latents[:, :k]
-						break
+					latents = autoregressive.forward(
+						autoregressive_latents,
+						text_tokens,
+						text_lengths,
+						codes,
+						wav_lengths,
+						return_latent=True,
+						clip_inputs=False
+					)
+
+					calm_tokens = 0
+					for k in range( codes.shape[-1] ):
+						if codes[0, k] == calm_token:
+							calm_tokens += 1
+						else:
+							calm_tokens = 0
+						if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+							latents = latents[:, :k]
+							break
 
 				# diffusion pass
-				output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-				output_shape = (latents.shape[0], 100, output_seq_len)
-				precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
+				with ml.auto_unload(diffusion, enabled=cfg.inference.auto_unload):
+					output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+					output_shape = (latents.shape[0], 100, output_seq_len)
+					precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
 
-				noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
-				mel = diffuser.sample_loop(
-					diffusion,
-					output_shape,
-					sampler=diffusion_sampler,
-					noise=noise,
-					model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
-					progress=True
-				)
-				mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+					noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
+					mel = diffuser.sample_loop(
+						diffusion,
+						output_shape,
+						sampler=diffusion_sampler,
+						noise=noise,
+						model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
+						progress=True
+					)
+					mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
 
 				# vocoder pass
-				waves = vocoder.inference(mels)
+				with ml.auto_unload(vocoder, enabled=cfg.inference.auto_unload):
+					waves = vocoder.inference(mels)
 
 		for wav in waves:
 			if out_path is not None:
diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py
index 2b45ba0..c796c86 100755
--- a/tortoise_tts/train.py
+++ b/tortoise_tts/train.py
@@ -229,6 +229,10 @@ def run_eval(engines, eval_name, dl):
 	else:
 		_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
 
+	diffusion = diffusion.to("cpu")
+	clvp = clvp.to("cpu")
+	vocoder = vocoder.to("cpu")
+
 def train():
 	parser = argparse.ArgumentParser("TorToiSe TTS")
 
diff --git a/tortoise_tts/utils/wrapper.py b/tortoise_tts/utils/wrapper.py
index 275d5a4..3320b9d 100755
--- a/tortoise_tts/utils/wrapper.py
+++ b/tortoise_tts/utils/wrapper.py
@@ -77,6 +77,14 @@ def autocasts(input, from_dtype, to_dtype):
 	else:
 		yield input
 
+@contextmanager
+def auto_unload( model, gpu="cuda", cpu="cpu", enabled=True):
+	model.to(gpu)
+	yield model
+
+	if enabled:
+		model.to(cpu)
+
 # handles temporarily upcasting 'index tensors' so torch will stop bitching
 def autocast_forward( func ):
 	def wrapper( self, input, *args, **kwargs ):
diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py
index 868712d..71fe38d 100644
--- a/tortoise_tts/webui.py
+++ b/tortoise_tts/webui.py
@@ -240,7 +240,7 @@ with ui:
 				with gr.Row():
 					layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
 					layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
-					layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
+					layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
 				"""
 				with gr.Row():
 					layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
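
As a companion to the patch, here is a minimal, self-contained sketch of how the new `auto_unload` context manager from `tortoise_tts/utils/wrapper.py` is meant to be used: the model is moved onto the GPU for the duration of the block, then shoved back to the CPU when the block exits. The `toy_model` and the CPU fallback when CUDA is unavailable are assumptions added to keep the example runnable, not part of the patch.

from contextlib import contextmanager

import torch

@contextmanager
def auto_unload(model, gpu="cuda", cpu="cpu", enabled=True):
	# Load the model onto the GPU for the duration of the `with` block.
	model.to(gpu)
	yield model
	# Once the block exits, shove the model back to the CPU (when enabled)
	# so the next model in the pipeline has the VRAM to itself.
	if enabled:
		model.to(cpu)

# Hypothetical usage; `toy_model` stands in for the autoregressive, diffusion,
# CLVP, or vocoder model. Falling back to "cpu" when CUDA is absent is only
# here so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
toy_model = torch.nn.Linear(4, 4)

with auto_unload(toy_model, gpu=device, enabled=True):
	out = toy_model(torch.randn(1, 4, device=device))

print(next(toy_model.parameters()).device)  # back on "cpu" after the block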