From b8b15d827d21d1152bfe61edc6e87a505fd07410 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 6 Feb 2023 05:10:07 +0000 Subject: [PATCH] added flag (--cond-latent-max-chunk-size) that should restrict the maximum chunk size when chunking for calculating conditional latents, to avoid OOMing on VRAM --- app.py | 3 ++- tortoise/api.py | 19 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/app.py b/app.py index 742ebdd..51825bb 100755 --- a/app.py +++ b/app.py @@ -27,7 +27,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, preset, seed, c if voice_samples is not None: sample_voice = voice_samples[0] - conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress) + conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress, max_chunk_size=args.cond_latent_max_chunk_size) torch.save(conditioning_latents, os.path.join(f'./tortoise/voices/{voice}/', f'cond_latents.pth')) voice_samples = None else: @@ -265,6 +265,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--share", action='store_true', help="Lets Gradio return a public URL to use anywhere") parser.add_argument("--low-vram", action='store_true', help="Disables some optimizations that increases VRAM usage") + parser.add_argument("--cond-latent-max-chunk-size", type=int, default=None, help="Sets an upper limit to audio chunk size when computing conditioning latents") args = parser.parse_args() tts = TextToSpeech(minor_optimizations=not args.low_vram) diff --git a/tortoise/api.py b/tortoise/api.py index f54e16c..dc14ff6 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -284,7 +284,7 @@ class TextToSpeech: if self.minor_optimizations: self.cvvp = self.cvvp.to(self.device) - def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, enforced_length=None, chunk_tensors=False): + def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, chunk_size=None, max_chunk_size=None, chunk_tensors=True): """ Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic @@ -309,26 +309,29 @@ class TextToSpeech: # The diffuser operates at a sample rate of 24000 (except for the latent inputs) samples.append(torchaudio.functional.resample(sample, 22050, 24000)) - if enforced_length is None: + if chunk_size is None: for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."): if chunk_tensors: - enforced_length = sample.shape[-1] if enforced_length is None else min( enforced_length, sample.shape[-1] ) + chunk_size = sample.shape[-1] if chunk_size is None else min( chunk_size, sample.shape[-1] ) else: - enforced_length = sample.shape[-1] if enforced_length is None else max( enforced_length, sample.shape[-1] ) + chunk_size = sample.shape[-1] if chunk_size is None else max( chunk_size, sample.shape[-1] ) - print(f"Size of best fit: {enforced_length}") + print(f"Size of best fit: {chunk_size}") + if max_chunk_size is not None and chunk_size > max_chunk_size: + chunk_size = max_chunk_size + print(f"Chunk size exceeded, clamping to: {max_chunk_size}") chunks = [] if chunk_tensors: for sample in tqdm_override(samples, verbose=verbose, progress=progress, desc="Slicing samples into chunks..."): - sliced = torch.chunk(sample, int(sample.shape[-1] / enforced_length) + 1, dim=1) + sliced = torch.chunk(sample, int(sample.shape[-1] / chunk_size) + 1, dim=1) for s in sliced: chunks.append(s) else: chunks = samples for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."): - chunk = pad_or_truncate(chunk, enforced_length) + chunk = pad_or_truncate(chunk, chunk_size) cond_mel = wav_to_univnet_mel(chunk.to(self.device), do_normalization=False, device=self.device) diffusion_conds.append(cond_mel) @@ -460,7 +463,7 @@ class TextToSpeech: diffusion_conditioning = diffusion_conditioning.to(self.device) diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k) - + with torch.no_grad(): samples = [] num_batches = num_autoregressive_samples // self.autoregressive_batch_size