fwd fix
This commit is contained in:
parent
f6a8b0a5ca
commit
035bcd9f6c
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
from trainer.inject import Injector
|
from trainer.inject import Injector
|
||||||
from utils.util import opt_get, load_model_from_config
|
from utils.util import opt_get, load_model_from_config, pad_or_truncate
|
||||||
|
|
||||||
TACOTRON_MEL_MAX = 2.3143386840820312
|
TACOTRON_MEL_MAX = 2.3143386840820312
|
||||||
TACOTRON_MEL_MIN = -11.512925148010254
|
TACOTRON_MEL_MIN = -11.512925148010254
|
||||||
|
@ -173,17 +173,16 @@ class GptVoiceLatentInjector(Injector):
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mel_inputs = self.to_mel(state[self.input])
|
mel_inputs = self.to_mel(state[self.input])
|
||||||
mel_cond = self.to_mel(state[self.conditioning_key])
|
state_cond = pad_or_truncate(state[self.conditioning_key], 88000)
|
||||||
|
mel_conds = []
|
||||||
|
for k in range(state_cond.shape[1]):
|
||||||
|
mel_conds.append(self.to_mel(state_cond[:, k]))
|
||||||
|
mel_conds = torch.stack(mel_conds, dim=1)
|
||||||
|
|
||||||
# Use the input as a conditioning input as well. This is fine because we are not actually training the GPT network so it can't learn to cheat.
|
|
||||||
max_mel_len = max(mel_inputs.shape[-1], mel_cond.shape[-1])
|
|
||||||
mel_cond = F.pad(mel_cond, (0, max_mel_len-mel_cond.shape[-1]))
|
|
||||||
mel_cond2 = F.pad(mel_inputs, (0, max_mel_len-mel_inputs.shape[-1]))
|
|
||||||
mel_cond = torch.cat([mel_cond.unsqueeze(1), mel_cond2.unsqueeze(1)], dim=1)
|
|
||||||
self.dvae = self.dvae.to(mel_inputs.device)
|
self.dvae = self.dvae.to(mel_inputs.device)
|
||||||
codes = self.dvae.get_codebook_indices(mel_inputs)
|
codes = self.dvae.get_codebook_indices(mel_inputs)
|
||||||
self.gpt = self.gpt.to(codes.device)
|
self.gpt = self.gpt.to(codes.device)
|
||||||
latents = self.gpt.forward(mel_cond, state[self.text_input_key],
|
latents = self.gpt.forward(mel_conds, state[self.text_input_key],
|
||||||
state[self.text_lengths_key], codes, state[self.input_lengths_key],
|
state[self.text_lengths_key], codes, state[self.input_lengths_key],
|
||||||
text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
|
text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
|
||||||
return {self.output: latents}
|
return {self.output: latents}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user