fwd fix
This commit is contained in:
parent
f6a8b0a5ca
commit
035bcd9f6c
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
import torchaudio
|
||||
|
||||
from trainer.inject import Injector
|
||||
from utils.util import opt_get, load_model_from_config
|
||||
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
||||
|
||||
TACOTRON_MEL_MAX = 2.3143386840820312
|
||||
TACOTRON_MEL_MIN = -11.512925148010254
|
||||
|
@ -173,17 +173,16 @@ class GptVoiceLatentInjector(Injector):
|
|||
def forward(self, state):
|
||||
with torch.no_grad():
|
||||
mel_inputs = self.to_mel(state[self.input])
|
||||
mel_cond = self.to_mel(state[self.conditioning_key])
|
||||
state_cond = pad_or_truncate(state[self.conditioning_key], 88000)
|
||||
mel_conds = []
|
||||
for k in range(state_cond.shape[1]):
|
||||
mel_conds.append(self.to_mel(state_cond[:, k]))
|
||||
mel_conds = torch.stack(mel_conds, dim=1)
|
||||
|
||||
# Use the input as a conditioning input as well. This is fine because we are not actually training the GPT network so it can't learn to cheat.
|
||||
max_mel_len = max(mel_inputs.shape[-1], mel_cond.shape[-1])
|
||||
mel_cond = F.pad(mel_cond, (0, max_mel_len-mel_cond.shape[-1]))
|
||||
mel_cond2 = F.pad(mel_inputs, (0, max_mel_len-mel_inputs.shape[-1]))
|
||||
mel_cond = torch.cat([mel_cond.unsqueeze(1), mel_cond2.unsqueeze(1)], dim=1)
|
||||
self.dvae = self.dvae.to(mel_inputs.device)
|
||||
codes = self.dvae.get_codebook_indices(mel_inputs)
|
||||
self.gpt = self.gpt.to(codes.device)
|
||||
latents = self.gpt.forward(mel_cond, state[self.text_input_key],
|
||||
latents = self.gpt.forward(mel_conds, state[self.text_input_key],
|
||||
state[self.text_lengths_key], codes, state[self.input_lengths_key],
|
||||
text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
|
||||
return {self.output: latents}
|
||||
|
|
Loading…
Reference in New Issue
Block a user