forked from mrq/tortoise-tts

fixed up the computing conditional latents

This commit is contained in:
mrq 2023-02-06 03:44:34 +00:00
parent 3c0648beaf
commit 319e7ec0a6
3 changed files with 32 additions and 13 deletions

tortoise/api.py

@@ -284,7 +284,7 @@ class TextToSpeech:
         if self.minor_optimizations:
             self.cvvp = self.cvvp.to(self.device)
 
-    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, enforced_length=102400):
+    def get_conditioning_latents(self, voice_samples, return_mels=False, verbose=False, progress=None, enforced_length=None, chunk_tensors=False):
         """
         Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent).
         These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic
@@ -304,12 +304,30 @@ class TextToSpeech:
         diffusion_conds = []
-        for sample in tqdm_override(voice_samples, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
+        samples = [] # resample in its own pass to make things easier
+        for sample in voice_samples:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
-            sample = torchaudio.functional.resample(sample, 22050, 24000)
-            chunks = torch.chunk(sample, int(sample.shape[-1] / enforced_length) + 1, dim=1)
-            for chunk in chunks:
+            samples.append(torchaudio.functional.resample(sample, 22050, 24000))
+
+        if enforced_length is None:
+            for sample in tqdm_override(samples, verbose=verbose and len(samples) > 1, progress=progress if len(samples) > 1 else None, desc="Calculating size of best fit..."):
+                if chunk_tensors:
+                    enforced_length = sample.shape[-1] if enforced_length is None else min( enforced_length, sample.shape[-1] )
+                else:
+                    enforced_length = sample.shape[-1] if enforced_length is None else max( enforced_length, sample.shape[-1] )
+            print(f"Size of best fit: {enforced_length}")
+
+        chunks = []
+        if chunk_tensors:
+            for sample in tqdm_override(samples, verbose=verbose, progress=progress, desc="Slicing samples into chunks..."):
+                sliced = torch.chunk(sample, int(sample.shape[-1] / enforced_length) + 1, dim=1)
+                for s in sliced:
+                    chunks.append(s)
+        else:
+            chunks = samples
+
+        for chunk in tqdm_override(chunks, verbose=verbose, progress=progress, desc="Computing conditioning latents..."):
             chunk = pad_or_truncate(chunk, enforced_length)
             cond_mel = wav_to_univnet_mel(chunk.to(self.device), do_normalization=False, device=self.device)
             diffusion_conds.append(cond_mel)
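
Read outside the diff, the new method body makes three passes: resample everything to 24kHz, pick one common length (the shortest clip when chunk_tensors is set, so every slice can fill it; the longest otherwise, so whole clips are never truncated), then cut and pad each piece to exactly that length. Below is a minimal self-contained sketch of the same flow; make_uniform_chunks is a hypothetical helper, not part of the fork, and pad_or_truncate is approximated with F.pad:

import torch
import torch.nn.functional as F

def make_uniform_chunks(samples, chunk_tensors=False, enforced_length=None):
    # samples: list of [1, T] waveforms, already resampled to 24kHz.
    # Pass 1: pick the common length. min() when slicing, so every chunk
    # can fill it; max() when keeping clips whole, so nothing is cut short.
    if enforced_length is None:
        lengths = [s.shape[-1] for s in samples]
        enforced_length = min(lengths) if chunk_tensors else max(lengths)

    # Pass 2: optionally slice each clip into roughly equal pieces.
    chunks = []
    if chunk_tensors:
        for sample in samples:
            chunks.extend(torch.chunk(sample, sample.shape[-1] // enforced_length + 1, dim=1))
    else:
        chunks = samples

    # Pass 3: pad or truncate every piece to exactly enforced_length,
    # standing in for tortoise's pad_or_truncate helper.
    out = []
    for chunk in chunks:
        if chunk.shape[-1] < enforced_length:
            chunk = F.pad(chunk, (0, enforced_length - chunk.shape[-1]))
        out.append(chunk[..., :enforced_length])
    return out

pieces = make_uniform_chunks([torch.randn(1, 48000), torch.randn(1, 100000)], chunk_tensors=True)
assert all(p.shape[-1] == 48000 for p in pieces)
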
@@ -424,6 +442,7 @@ class TextToSpeech:
         :return: Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
                  Sample rate is 24kHz.
         """
+        self.diffusion.enable_fp16 = half_p
         deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
 
         text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
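
The added line pushes the caller's half_p choice into the diffusion decoder before any sampling happens, so precision becomes a per-call decision rather than a construction-time one. As a sketch of one common way such a flag is wired (an assumption, not this fork's actual forward pass), it can gate an autocast region:

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # Toy stand-in for DiffusionTts; only the fp16 gating is the point.
    def __init__(self, enable_fp16=True):
        super().__init__()
        self.enable_fp16 = enable_fp16
        self.proj = nn.Linear(256, 200)  # 200 outputs ~ mean and variance

    def forward(self, x):
        # Run in half precision only when the flag is set and we're on CUDA.
        with torch.autocast(device_type=x.device.type, dtype=torch.float16,
                            enabled=self.enable_fp16 and x.is_cuda):
            return self.proj(x)

decoder = TinyDecoder()
decoder.enable_fp16 = False  # what tts(..., half_p=False) now does per call
out = decoder(torch.randn(2, 256))
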
@@ -432,7 +451,7 @@ class TextToSpeech:
         auto_conds = None
         if voice_samples is not None:
-            auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True)
+            auto_conditioning, diffusion_conditioning, auto_conds, _ = self.get_conditioning_latents(voice_samples, return_mels=True, verbose=True)
         elif conditioning_latents is not None:
             auto_conditioning, diffusion_conditioning = conditioning_latents
         else:
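
Taken together, the api.py changes shift how callers drive latent computation: enforced_length defaults to None and is computed from the clips, chunk_tensors selects slicing versus whole-clip padding, and progress is now reported when tts() computes latents itself. A usage sketch against this fork's API; the clip paths are placeholders:

from tortoise.api import TextToSpeech
from tortoise.utils.audio import load_audio

tts = TextToSpeech()
voice_samples = [load_audio(p, 22050) for p in ('clip1.wav', 'clip2.wav')]

# New defaults: the best-fit length is computed from the clips (their max,
# since chunk_tensors=False) and each clip is padded up to it whole.
auto_latent, diffusion_latent = tts.get_conditioning_latents(voice_samples, verbose=True)

# Opt in to slicing: every clip is chunked against the shortest sample.
auto_latent, diffusion_latent = tts.get_conditioning_latents(voice_samples, chunk_tensors=True)
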

tortoise/get_conditioning_latents.py Normal file → Executable file

@@ -25,6 +25,6 @@ if __name__ == '__main__':
     for cond_path in cond_paths:
         c = load_audio(cond_path, 22050)
         conds.append(c)
-    conditioning_latents = tts.get_conditioning_latents(conds)
+    conditioning_latents = tts.get_conditioning_latents(conds, verbose=True)
 
     torch.save(conditioning_latents, os.path.join(args.output_path, f'{voice}.pth'))

tortoise/models/diffusion_decoder.py Normal file → Executable file

@@ -141,7 +141,7 @@ class DiffusionTts(nn.Module):
             in_tokens=8193,
             out_channels=200, # mean and variance
             dropout=0,
-            use_fp16=False,
+            use_fp16=True,
             num_heads=16,
             # Parameters for regularization.
             layer_drop=.1,
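
Flipping the constructor default to use_fp16=True works hand in hand with the per-call override added in api.py above: DiffusionTts keeps the argument as self.enable_fp16 (as the api.py hunk suggests), and the assignment tts() makes just before sampling wins over this default. A small sketch of that precedence, relying on the constructor's keyword defaults:

from tortoise.models.diffusion_decoder import DiffusionTts

decoder = DiffusionTts()      # fp16 now defaults to on in this fork
decoder.enable_fp16 = False   # a half_p=False call still forces full precision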