added flag (--cond-latent-max-chunk-size) that should restrict the maximum chunk size when chunking for calculating conditional latents, to avoid OOMing on VRAM

This commit is contained in:
mrq 2023-02-06 05:10:07 +00:00
parent a1f3b6a4da
commit b441a84615
2 changed files with 13 additions and 9 deletions

3
app.py
View File

@ -27,7 +27,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, preset, seed, c
if voice_samples is not None: if voice_samples is not None:
sample_voice = voice_samples[0] sample_voice = voice_samples[0]
conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress) conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress, max_chunk_size=args.cond_latent_max_chunk_size)
torch.save(conditioning_latents, os.path.join(f'./tortoise/voices/{voice}/', f'cond_latents.pth')) torch.save(conditioning_latents, os.path.join(f'./tortoise/voices/{voice}/', f'cond_latents.pth'))
voice_samples = None voice_samples = None
else: else:
@ -265,6 +265,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', help="Lets Gradio return a public URL to use anywhere") parser.add_argument("--share", action='store_true', help="Lets Gradio return a public URL to use anywhere")
parser.add_argument("--low-vram", action='store_true', help="Disables some optimizations that increases VRAM usage") parser.add_argument("--low-vram", action='store_true', help="Disables some optimizations that increases VRAM usage")
parser.add_argument("--cond-latent-max-chunk-size", type=int, default=None, help="Sets an upper limit to audio chunk size when computing conditioning latents")
args = parser.parse_args() args = parser.parse_args()
tts = TextToSpeech(minor_optimizations=not args.low_vram) tts = TextToSpeech(minor_optimizations=not args.low_vram)

View File

@ -284,7 +284,7 @@ class TextToSpeech:
if self.minor_optimizations: if self.minor_optimizations:
self.cvvp = self.cvvp.to(self.device) self.cvvp = self.cvvp.to(self.device)
def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, enforced_length=None, chunk_tensors=False): def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, chunk_size=None, max_chunk_size=None, chunk_tensors=True):
""" """
Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
@ -309,26 +309,29 @@ class TextToSpeech:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs) # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
samples.append(torchaudio.functional.resample(sample, 22050, 24000)) samples.append(torchaudio.functional.resample(sample, 22050, 24000))
if enforced_length is None: if chunk_size is None:
for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."): for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."):
if chunk_tensors: if chunk_tensors:
enforced_length = sample.shape[-1] if enforced_length is None else min( enforced_length, sample.shape[-1] ) chunk_size = sample.shape[-1] if chunk_size is None else min( chunk_size, sample.shape[-1] )
else: else:
enforced_length = sample.shape[-1] if enforced_length is None else max( enforced_length, sample.shape[-1] ) chunk_size = sample.shape[-1] if chunk_size is None else max( chunk_size, sample.shape[-1] )
print(f"Size of best fit: {enforced_length}") print(f"Size of best fit: {chunk_size}")
if max_chunk_size is not None and chunk_size > max_chunk_size:
chunk_size = max_chunk_size
print(f"Chunk size exceeded, clamping to: {max_chunk_size}")
chunks = [] chunks = []
if chunk_tensors: if chunk_tensors:
for sample in tqdm_override(samples, verbose=verbose, progress=progress, desc="Slicing samples into chunks..."): for sample in tqdm_override(samples, verbose=verbose, progress=progress, desc="Slicing samples into chunks..."):
sliced = torch.chunk(sample, int(sample.shape[-1] / enforced_length) + 1, dim=1) sliced = torch.chunk(sample, int(sample.shape[-1] / chunk_size) + 1, dim=1)
for s in sliced: for s in sliced:
chunks.append(s) chunks.append(s)
else: else:
chunks = samples chunks = samples
for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."): for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
chunk = pad_or_truncate(chunk, enforced_length) chunk = pad_or_truncate(chunk, chunk_size)
cond_mel = wav_to_univnet_mel(chunk.to(self.device), do_normalization=False, device=self.device) cond_mel = wav_to_univnet_mel(chunk.to(self.device), do_normalization=False, device=self.device)
diffusion_conds.append(cond_mel) diffusion_conds.append(cond_mel)