James Betker 2022-04-01 16:03:07 -06:00
parent f6a8b0a5ca
commit 035bcd9f6c


@@ -5,7 +5,7 @@ import torch.nn.functional as F
 import torchaudio
 from trainer.inject import Injector
-from utils.util import opt_get, load_model_from_config
+from utils.util import opt_get, load_model_from_config, pad_or_truncate
 
 TACOTRON_MEL_MAX = 2.3143386840820312
 TACOTRON_MEL_MIN = -11.512925148010254
@@ -173,17 +173,16 @@ class GptVoiceLatentInjector(Injector):
     def forward(self, state):
         with torch.no_grad():
             mel_inputs = self.to_mel(state[self.input])
-            mel_cond = self.to_mel(state[self.conditioning_key])
-            # Use the input as a conditioning input as well. This is fine because we are not actually training the GPT network so it can't learn to cheat.
-            max_mel_len = max(mel_inputs.shape[-1], mel_cond.shape[-1])
-            mel_cond = F.pad(mel_cond, (0, max_mel_len-mel_cond.shape[-1]))
-            mel_cond2 = F.pad(mel_inputs, (0, max_mel_len-mel_inputs.shape[-1]))
-            mel_cond = torch.cat([mel_cond.unsqueeze(1), mel_cond2.unsqueeze(1)], dim=1)
+            state_cond = pad_or_truncate(state[self.conditioning_key], 88000)
+            mel_conds = []
+            for k in range(state_cond.shape[1]):
+                mel_conds.append(self.to_mel(state_cond[:, k]))
+            mel_conds = torch.stack(mel_conds, dim=1)
             self.dvae = self.dvae.to(mel_inputs.device)
             codes = self.dvae.get_codebook_indices(mel_inputs)
             self.gpt = self.gpt.to(codes.device)
-            latents = self.gpt.forward(mel_cond, state[self.text_input_key],
+            latents = self.gpt.forward(mel_conds, state[self.text_input_key],
                                        state[self.text_lengths_key], codes, state[self.input_lengths_key],
                                        text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
             return {self.output: latents}
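
`pad_or_truncate` is imported from `utils.util`, but its body is not part of this diff. A minimal sketch of the assumed behavior, based on its name and the call site `pad_or_truncate(state[self.conditioning_key], 88000)`: force the last (sample) dimension of a raw-audio tensor to a fixed length, zero-padding short clips and truncating long ones.

```python
import torch
import torch.nn.functional as F

def pad_or_truncate(t: torch.Tensor, length: int) -> torch.Tensor:
    # Assumed behavior: make the last dimension exactly `length` samples,
    # zero-padding clips that are too short and clipping ones that are too long.
    if t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))
    return t[..., :length]
```

Assuming 22,050 Hz audio, 88000 samples is roughly four seconds. The change thus replaces the old scheme (pad the conditioning mel and the input mel to a common length and concatenate them) with a fixed-length window per conditioning clip: each clip along dimension 1 of `state[self.conditioning_key]` is normalized to the same length, converted to a mel spectrogram, and stacked into `mel_conds` before being passed to `self.gpt.forward`.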