forked from mrq/tortoise-tts
added a flag (--cond-latent-max-chunk-size) that restricts the maximum chunk size used when chunking audio for calculating conditioning latents, to avoid OOMing on VRAM
parent 319e7ec0a6
commit b8b15d827d

app.py (3 changed lines)
@@ -27,7 +27,7 @@ def generate(text, delimiter, emotion, prompt, voice, mic_audio, preset, seed, c
     if voice_samples is not None:
         sample_voice = voice_samples[0]
-        conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress)
+        conditioning_latents = tts.get_conditioning_latents(voice_samples, progress=progress, max_chunk_size=args.cond_latent_max_chunk_size)
         torch.save(conditioning_latents, os.path.join(f'./tortoise/voices/{voice}/', f'cond_latents.pth'))
         voice_samples = None
     else:
@@ -265,6 +265,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--share", action='store_true', help="Lets Gradio return a public URL to use anywhere")
     parser.add_argument("--low-vram", action='store_true', help="Disables some optimizations that increases VRAM usage")
+    parser.add_argument("--cond-latent-max-chunk-size", type=int, default=None, help="Sets an upper limit to audio chunk size when computing conditioning latents")
     args = parser.parse_args()

     tts = TextToSpeech(minor_optimizations=not args.low_vram)
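With the flag in place, the web UI might be launched as, for example, python app.py --cond-latent-max-chunk-size 131072 (the value is purely illustrative); omitting the flag keeps the default of None, which skips clamping and preserves the previous behavior.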
tortoise/api.py

@@ -284,7 +284,7 @@ class TextToSpeech:
         if self.minor_optimizations:
             self.cvvp = self.cvvp.to(self.device)

-    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, enforced_length=None, chunk_tensors=False):
+    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, chunk_size=None, max_chunk_size=None, chunk_tensors=True):
         """
         Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
         These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
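For illustration, a caller-side sketch of the updated signature (assuming tts and voice_samples already exist, as in app.py above; the 131072 value is an arbitrary example, not a tuned default):

# Hypothetical usage sketch of the new keyword argument.
conditioning_latents = tts.get_conditioning_latents(
    voice_samples,
    max_chunk_size=131072,  # hypothetical cap on the computed chunk size; None (the default) disables clamping
)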
@@ -309,26 +309,29 @@ class TextToSpeech:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
             samples.append(torchaudio.functional.resample(sample, 22050, 24000))

-        if enforced_length is None:
+        if chunk_size is None:
             for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."):
                 if chunk_tensors:
-                    enforced_length = sample.shape[-1] if enforced_length is None else min( enforced_length, sample.shape[-1] )
+                    chunk_size = sample.shape[-1] if chunk_size is None else min( chunk_size, sample.shape[-1] )
                 else:
-                    enforced_length = sample.shape[-1] if enforced_length is None else max( enforced_length, sample.shape[-1] )
+                    chunk_size = sample.shape[-1] if chunk_size is None else max( chunk_size, sample.shape[-1] )

-        print(f"Size of best fit: {enforced_length}")
+        print(f"Size of best fit: {chunk_size}")
+        if max_chunk_size is not None and chunk_size > max_chunk_size:
+            chunk_size = max_chunk_size
+            print(f"Chunk size exceeded, clamping to: {max_chunk_size}")

         chunks = []
         if chunk_tensors:
             for sample in tqdm_override(samples, verbose=verbose, progress=progress, desc="Slicing samples into chunks..."):
-                sliced = torch.chunk(sample, int(sample.shape[-1] / enforced_length) + 1, dim=1)
+                sliced = torch.chunk(sample, int(sample.shape[-1] / chunk_size) + 1, dim=1)
                 for s in sliced:
                     chunks.append(s)
         else:
             chunks = samples

         for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
-            chunk = pad_or_truncate(chunk, enforced_length)
+            chunk = pad_or_truncate(chunk, chunk_size)
             cond_mel = wav_to_univnet_mel(chunk.to(self.device), do_normalization=False, device=self.device)
             diffusion_conds.append(cond_mel)
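To make the slicing math concrete, here is a standalone sketch (not part of the commit; the sizes are arbitrary examples). torch.chunk(t, n, dim=1) yields pieces of size ceil(L / n), the last possibly smaller, so with n = int(L / chunk_size) + 1 every piece stays at or below chunk_size; pad_or_truncate then brings each piece back up to exactly chunk_size before the mel computation, so every chunk sent to the GPU has the same bounded length.

import torch

chunk_size = 24000                # arbitrary example: one second at the diffuser's 24 kHz rate
sample = torch.randn(1, 100000)   # stand-in for a resampled voice clip
n = int(sample.shape[-1] / chunk_size) + 1     # 5 pieces for a 100000-sample clip
sliced = torch.chunk(sample, n, dim=1)
print([s.shape[-1] for s in sliced])           # [20000, 20000, 20000, 20000, 20000]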