James Betker 2022-04-22 11:34:05 -06:00
parent da31baad21
commit 84d641c57a
2 changed files with 15 additions and 7 deletions

api.py (20 changed lines)

@@ -6,6 +6,7 @@ from urllib import request
 import torch
 import torch.nn.functional as F
 import progressbar
+import torchaudio

 from models.cvvp import CVVP
 from models.diffusion_decoder import DiffusionTts
@@ -118,29 +119,36 @@ def fix_autoregressive_output(codes, stop_token, complain=True):
     return codes


-def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
+def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_samples, temperature=1):
     """
     Uses the specified diffusion model to convert discrete codes into a spectrogram.
     """
     with torch.no_grad():
         cond_mels = []
         for sample in conditioning_samples:
+            # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
+            sample = torchaudio.functional.resample(sample, 22050, 24000)
             sample = pad_or_truncate(sample, 102400)
-            cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
+            cond_mel = wav_to_univnet_mel(sample.to(latents.device), do_normalization=False)
             cond_mels.append(cond_mel)
         cond_mels = torch.stack(cond_mels, dim=1)

-        output_seq_len = mel_codes.shape[1]*4*24000//22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
-        output_shape = (mel_codes.shape[0], 100, output_seq_len)
-        precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False)
+        output_seq_len = latents.shape[1] * 4 * 24000 // 22050  # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+        output_shape = (latents.shape[0], 100, output_seq_len)
+        precomputed_embeddings = diffusion_model.timestep_independent(latents, cond_mels, output_seq_len, False)

-        noise = torch.randn(output_shape, device=mel_codes.device) * temperature
+        noise = torch.randn(output_shape, device=latents.device) * temperature
         mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
                                      model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
         return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
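The renamed latents argument drives the output length: each latent frame expands to four spectrogram frames, and that frame count is rescaled from the 22.05 kHz code domain to the 24 kHz domain the diffuser operates in. A minimal standalone sketch of the new resample step and the length arithmetic, with hypothetical tensor shapes (the sizes below are illustrative, not taken from the commit):

import torch
import torchaudio

# Hypothetical sizes, chosen only to illustrate the arithmetic.
latents = torch.randn(1, 250, 1024)   # (batch, latent frames, channels)
clip = torch.randn(1, 22050 * 3)      # 3 s of 22.05 kHz conditioning audio

# Conditioning clips are resampled because the diffuser runs at 24 kHz.
clip_24k = torchaudio.functional.resample(clip, 22050, 24000)

# Four spectrogram frames per latent frame, rescaled from 22.05 kHz to 24 kHz.
output_seq_len = latents.shape[1] * 4 * 24000 // 22050   # 250*4*24000//22050 = 1088
output_shape = (latents.shape[0], 100, output_seq_len)   # 100 mel bins per frame

The same hunk continues into the TextToSpeech class, which gains a docstring: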
 class TextToSpeech:
+    """
+    Main entry point into Tortoise.
+    :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
+                                      GPU OOM errors. Larger numbers generates slightly faster.
+    """
     def __init__(self, autoregressive_batch_size=16):
         self.autoregressive_batch_size = autoregressive_batch_size
         self.tokenizer = VoiceBpeTokenizer()
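That docstring documents the memory/speed trade-off on autoregressive_batch_size. A usage sketch (assuming api.py is importable from the working directory; the constructor signature is only what is shown above):

from api import TextToSpeech

# The default is 16; lower it if the autoregressive pass hits GPU OOM,
# raise it for slightly faster generation at the cost of more VRAM.
tts = TextToSpeech(autoregressive_batch_size=4)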

models/autoregressive.py

@@ -356,7 +356,7 @@ class UnifiedVoice(nn.Module):
         preformatting to create a working TTS model.
         """
         # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
-        mel_lengths = wav_lengths // self.mel_length_compression
+        mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
         for b in range(len(mel_lengths)):
             actual_end = mel_lengths[b] + 1  # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
             if actual_end < mel_input_tokens.shape[-1]:
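This one-line change swaps tensor floor division for an explicit torch.div. In PyTorch releases of this era, // on tensors routed through torch.floor_divide, which was deprecated (it actually truncates rather than floors), so torch.div(..., rounding_mode='trunc') produces the same result for the non-negative lengths used here while avoiding the deprecation warning. A standalone sketch with a hypothetical compression factor:

import torch

wav_lengths = torch.tensor([22050, 44100, 36000])
mel_length_compression = 1024  # hypothetical value, for illustration only

# Old form (deprecated on tensors at the time):
#   mel_lengths = wav_lengths // mel_length_compression
# New form: explicit truncating integer division, identical for non-negative inputs.
mel_lengths = torch.div(wav_lengths, mel_length_compression, rounding_mode='trunc')
print(mel_lengths)  # tensor([21, 43, 35])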