From 035bcd9f6ccf1e66963126467054acbd6da24aa4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 1 Apr 2022 16:03:07 -0600 Subject: [PATCH] fwd fix --- codes/trainer/injectors/audio_injectors.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py index 24a83bee..3120cb39 100644 --- a/codes/trainer/injectors/audio_injectors.py +++ b/codes/trainer/injectors/audio_injectors.py @@ -5,7 +5,7 @@ import torch.nn.functional as F import torchaudio from trainer.inject import Injector -from utils.util import opt_get, load_model_from_config +from utils.util import opt_get, load_model_from_config, pad_or_truncate TACOTRON_MEL_MAX = 2.3143386840820312 TACOTRON_MEL_MIN = -11.512925148010254 @@ -173,17 +173,16 @@ class GptVoiceLatentInjector(Injector): def forward(self, state): with torch.no_grad(): mel_inputs = self.to_mel(state[self.input]) - mel_cond = self.to_mel(state[self.conditioning_key]) + state_cond = pad_or_truncate(state[self.conditioning_key], 88000) + mel_conds = [] + for k in range(state_cond.shape[1]): + mel_conds.append(self.to_mel(state_cond[:, k])) + mel_conds = torch.stack(mel_conds, dim=1) - # Use the input as a conditioning input as well. This is fine because we are not actually training the GPT network so it can't learn to cheat. - max_mel_len = max(mel_inputs.shape[-1], mel_cond.shape[-1]) - mel_cond = F.pad(mel_cond, (0, max_mel_len-mel_cond.shape[-1])) - mel_cond2 = F.pad(mel_inputs, (0, max_mel_len-mel_inputs.shape[-1])) - mel_cond = torch.cat([mel_cond.unsqueeze(1), mel_cond2.unsqueeze(1)], dim=1) self.dvae = self.dvae.to(mel_inputs.device) codes = self.dvae.get_codebook_indices(mel_inputs) self.gpt = self.gpt.to(codes.device) - latents = self.gpt.forward(mel_cond, state[self.text_input_key], + latents = self.gpt.forward(mel_conds, state[self.text_input_key], state[self.text_lengths_key], codes, state[self.input_lengths_key], text_first=True, raw_mels=None, return_attentions=False, return_latent=True) return {self.output: latents}