From 732deaa2123897541f83cac1724f7e5e77f45592 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 12 Apr 2022 20:53:09 -0600 Subject: [PATCH] support latents into the diffusion decoder --- api.py | 21 +++++++++++------- eval_multiple.py | 2 +- models/autoregressive.py | 44 ++++++++++++++++++++++++------------- models/diffusion_decoder.py | 17 +++++++++----- 4 files changed, 55 insertions(+), 29 deletions(-) diff --git a/api.py b/api.py index f5b2cd6..204c91f 100644 --- a/api.py +++ b/api.py @@ -117,7 +117,7 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_ cond_mels.append(cond_mel) cond_mels = torch.stack(cond_mels, dim=1) - output_seq_len = mel_codes.shape[-1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_seq_len = mel_codes.shape[1]*4*24000//22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. output_shape = (mel_codes.shape[0], 100, output_seq_len) precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, output_seq_len, False) @@ -151,11 +151,6 @@ class TextToSpeech: layer_drop=0, unconditioned_percentage=0).cpu().eval() self.diffusion.load_state_dict(torch.load('.models/diffusion.pth')) - self.diffusion_next = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200, - in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16, - layer_drop=0, unconditioned_percentage=0).cpu().eval() - self.diffusion_next.load_state_dict(torch.load('.models/diffusion_next.pth')) - self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load('.models/vocoder.pth')['model_g']) self.vocoder.eval(inference=True) @@ -223,12 +218,22 @@ class TextToSpeech: self.clip = self.clip.cpu() del samples + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. + self.autoregressive = self.autoregressive.cuda() + best_latents = self.autoregressive(conds, text, torch.tensor([text.shape[-1]], device=conds.device), best_results, + torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=conds.device), + return_latent=True, clip_inputs=False) + self.autoregressive = self.autoregressive.cpu() + print("Performing vocoding..") wav_candidates = [] self.diffusion = self.diffusion.cuda() self.vocoder = self.vocoder.cuda() for b in range(best_results.shape[0]): codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) # Find the first occurrence of the "calm" token and trim the codes to that. ctokens = 0 @@ -238,10 +243,10 @@ class TextToSpeech: else: ctokens = 0 if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. - codes = codes[:, :k] + latents = latents[:, :k] break - mel = do_spectrogram_diffusion(self.diffusion, diffuser, codes, voice_samples, temperature=diffusion_temperature) + mel = do_spectrogram_diffusion(self.diffusion, diffuser, latents, voice_samples, temperature=diffusion_temperature) wav = self.vocoder.inference(mel) wav_candidates.append(wav.cpu()) self.diffusion = self.diffusion.cpu() diff --git a/eval_multiple.py b/eval_multiple.py index c55cdc1..9f1919d 100644 --- a/eval_multiple.py +++ b/eval_multiple.py @@ -7,7 +7,7 @@ from utils.audio import load_audio if __name__ == '__main__': fname = 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv' - outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_auto_256_samp_100_di_4' + outpath = 'D:\\tmp\\tortoise-tts-eval\\diverse_new_decoder_1' outpath_real = 'D:\\tmp\\tortoise-tts-eval\\real' os.makedirs(outpath, exist_ok=True) diff --git a/models/autoregressive.py b/models/autoregressive.py index 6f40ca7..64fd451 100644 --- a/models/autoregressive.py +++ b/models/autoregressive.py @@ -362,7 +362,7 @@ class UnifiedVoice(nn.Module): mel_input_tokens[b, actual_end:] = self.stop_mel_token return mel_input_tokens - def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False): + def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False): if second_inputs is not None: emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) else: @@ -374,6 +374,10 @@ class UnifiedVoice(nn.Module): enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input enc = self.final_norm(enc) + + if return_latent: + return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:] + first_logits = enc[:, :first_inputs.shape[1]] first_logits = first_head(first_logits) first_logits = first_logits.permute(0,2,1) @@ -385,7 +389,8 @@ class UnifiedVoice(nn.Module): else: return first_logits - def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False): + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False, + return_latent=False, clip_inputs=True): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode (actuated by `text_first`). @@ -396,19 +401,23 @@ class UnifiedVoice(nn.Module): mel_inputs: long tensor, (b,m) wav_lengths: long tensor, (b,) raw_mels: MEL float tensor (b,80,s) - """ - assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' - assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' - # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by - # chopping the inputs by the maximum actual length. - max_text_len = text_lengths.max() - text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) - max_mel_len = wav_lengths.max() // self.mel_length_compression - mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) - if raw_mels is not None: - raw_mels = raw_mels[:, :, :max_mel_len*4] + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token) speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input conds = [] @@ -427,10 +436,15 @@ class UnifiedVoice(nn.Module): mel_inp = mel_codes mel_emb = self.mel_embedding(mel_inp) mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + if text_first: - text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions) + text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. else: - mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions) + mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent) + if return_latent: + return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. if return_attentions: return mel_logits diff --git a/models/diffusion_decoder.py b/models/diffusion_decoder.py index 1baf809..5fdf7ad 100644 --- a/models/diffusion_decoder.py +++ b/models/diffusion_decoder.py @@ -176,7 +176,13 @@ class DiffusionTts(nn.Module): AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), ) self.code_norm = normalization(model_channels) - self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1) + self.latent_conditioner = nn.Sequential( + nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2), nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2), AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False), @@ -190,6 +196,7 @@ class DiffusionTts(nn.Module): DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads), ) + self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) @@ -206,7 +213,7 @@ class DiffusionTts(nn.Module): groups = { 'minicoder': list(self.contextual_embedder.parameters()), 'layers': list(self.layers.parameters()), - 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()), + 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()), 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), 'time_embed': list(self.time_embed.parameters()), } @@ -227,7 +234,7 @@ class DiffusionTts(nn.Module): cond_emb = conds.mean(dim=-1) cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1) if is_latent(aligned_conditioning): - code_emb = self.autoregressive_latent_converter(aligned_conditioning) + code_emb = self.latent_conditioner(aligned_conditioning) else: code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) code_emb = self.code_converter(code_emb) @@ -269,7 +276,7 @@ class DiffusionTts(nn.Module): if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: code_emb = precomputed_aligned_embeddings @@ -278,7 +285,7 @@ class DiffusionTts(nn.Module): if is_latent(aligned_conditioning): unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) else: - unused_params.extend(list(self.latent_converter.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.append(self.unconditioned_embedding)