From 8ea5c307fb76bf2d4faea05c1a2404b3c2ab9b81 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 11 Apr 2022 11:02:44 -0600
Subject: [PATCH] Fixes for training the diffusion model on autoregressive
 inputs

The AR model's latent outputs are made suitable for conditioning the
diffusion decoder:

- get_logits() was returning the conditioning segment rather than the
  first-input segment when return_latent was set; offset the slice by the
  length of the conditioning embeddings.
- forward() unconditionally clipped its inputs to the batch max length,
  which breaks callers that need latents aligned to precomputed DVAE
  codes; gate the clipping behind a new clip_inputs argument.
- The returned latents retained one of the two tokens appended during the
  forward pass; strip both so the latents line up 1:1 with the input
  codes.
- GptVoiceLatentInjector moves the DVAE and GPT to the input device only
  once, calls forward() with clip_inputs=False, and asserts that the
  latent length matches the code length.
---
 codes/models/audio/tts/unified_voice2.py   | 28 ++++++++++++----------
 codes/trainer/injectors/audio_injectors.py | 11 ++++++---
 2 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/codes/models/audio/tts/unified_voice2.py b/codes/models/audio/tts/unified_voice2.py
index 435526c6..22d6fce6 100644
--- a/codes/models/audio/tts/unified_voice2.py
+++ b/codes/models/audio/tts/unified_voice2.py
@@ -353,7 +353,7 @@ class UnifiedVoice(nn.Module):
         enc = self.final_norm(enc)
 
         if return_latent:
-            return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
 
         first_logits = enc[:, :first_inputs.shape[1]]
         first_logits = first_head(first_logits)
@@ -367,7 +367,7 @@ class UnifiedVoice(nn.Module):
             return first_logits
 
     def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
-                return_latent=False):
+                return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -381,16 +381,20 @@ class UnifiedVoice(nn.Module):
 
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, each input is clipped to the maximum actual length in the batch for its modality.
         """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs to the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
         speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
         conds = []
@@ -413,11 +417,11 @@ class UnifiedVoice(nn.Module):
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1]  # Despite the name, these are not logits.
+                return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1]  # Despite the name, these are not logits
+                return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
 
         if return_attentions:
             return mel_logits
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 3120cb39..9bcfb4cd 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -159,6 +159,7 @@ class GptVoiceLatentInjector(Injector):
             pretrained_path = opt['gpt_path']
             self.gpt = load_model_from_config(cfg, model_name=model_name, also_load_savepoint=False,
                                               load_path=pretrained_path).cuda().eval()
+        self.needs_move = True
         # Mel converter
         self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
         # Aux input keys.
@@ -179,10 +180,14 @@ class GptVoiceLatentInjector(Injector):
                 mel_conds.append(self.to_mel(state_cond[:, k]))
             mel_conds = torch.stack(mel_conds, dim=1)
 
-            self.dvae = self.dvae.to(mel_inputs.device)
+            if self.needs_move:
+                self.dvae = self.dvae.to(mel_inputs.device)
+                self.gpt = self.gpt.to(mel_inputs.device)
+                self.needs_move = False  # Only move the models once.
             codes = self.dvae.get_codebook_indices(mel_inputs)
-            self.gpt = self.gpt.to(codes.device)
             latents = self.gpt.forward(mel_conds, state[self.text_input_key], state[self.text_lengths_key], codes,
                                        state[self.input_lengths_key],
-                                       text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
+                                       text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
+                                       clip_inputs=False)
+            assert latents.shape[1] == codes.shape[1]  # Latents must align 1:1 with the DVAE codes.
             return {self.output: latents}
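
Note for reviewers: the shape bookkeeping behind the new [:, :-2] slice and
the injector's assert can be checked in isolation. The sketch below is a
minimal, self-contained approximation: the token ids, hidden size, and the
randn stand-in for the transformer are hypothetical, but the padding mirrors
what forward() and the aligned-input construction do (append one stop token,
then prepend one start token).

# Toy sanity check of the latent/code alignment enforced by this patch.
# Token ids and hidden size are illustrative only; the real values live on
# the UnifiedVoice instance (self.start_mel_token / self.stop_mel_token).
import torch
import torch.nn.functional as F

START_MEL_TOKEN = 8192  # hypothetical
STOP_MEL_TOKEN = 8193   # hypothetical

def check_latent_alignment(mel_codes):
    b, t = mel_codes.shape
    # forward() appends one stop token: (b, t) -> (b, t+1)
    mel_codes = F.pad(mel_codes, (0, 1), value=STOP_MEL_TOKEN)
    # the aligned-input construction prepends one start token: -> (b, t+2)
    aligned = F.pad(mel_codes, (1, 0), value=START_MEL_TOKEN)
    # stand-in for the GPT: one latent vector per input position
    latents = torch.randn(b, aligned.shape[1], 1024)
    # the [:, :-2] slice drops the two positions added above, so each
    # remaining latent corresponds to exactly one of the original codes
    latents = latents[:, :-2]
    assert latents.shape[1] == t  # same check the injector now performs

check_latent_alignment(torch.zeros(2, 17, dtype=torch.long))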