Fixes for training the diffusion model on autoregressive inputs

James Betker 2022-04-11 11:02:44 -06:00
parent a3622462c1
commit 8ea5c307fb
2 changed files with 23 additions and 15 deletions

View File

@@ -353,7 +353,7 @@ class UnifiedVoice(nn.Module):
         enc = self.final_norm(enc)
 
         if return_latent:
-            return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
 
         first_logits = enc[:, :first_inputs.shape[1]]
         first_logits = first_head(first_logits)
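The hunk above is the substance of the fix: get_logits() runs the GPT over a single sequence laid out as [conditioning | first_inputs | second_inputs], so the old slice enc[:, :first_inputs.shape[1]] returned the latents of the conditioning prefix rather than those of the first modality. A minimal sketch of the corrected indexing follows; the names and shapes (cond_len, first_len, second_len) are illustrative stand-ins, not identifiers from this repository:

    import torch

    # Toy stand-in for the GPT's hidden states over the concatenated
    # sequence [conditioning | first_inputs | second_inputs].
    cond_len, first_len, second_len, dim = 2, 5, 7, 4
    enc = torch.randn(1, cond_len + first_len + second_len, dim)

    # Old (buggy): returns the conditioning latents, not first_inputs'.
    buggy = enc[:, :first_len]

    # Fixed: skip past the conditioning prefix before slicing.
    first_latents = enc[:, cond_len:cond_len + first_len]
    second_latents = enc[:, -second_len:]

    assert first_latents.shape[1] == first_len
    assert second_latents.shape[1] == second_len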
@ -367,7 +367,7 @@ class UnifiedVoice(nn.Module):
return first_logits return first_logits
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False, def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False): return_latent=False, clip_inputs=True):
""" """
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`). (actuated by `text_first`).
@@ -381,16 +381,20 @@ class UnifiedVoice(nn.Module):
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs by the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
         speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
         conds = []
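Two behavioral changes are bundled in this hunk. First, the padding-removal step is now gated behind clip_inputs: despite the docstring's phrasing, it trims each modality to the longest actual length in the micro-batch (the smallest size that drops no real content), and callers that need latents aligned to an externally produced code sequence can now disable it. Second, stop tokens are no longer appended as part of that clipping: set_mel_padding() runs first, and the F.pad calls then append exactly one stop token per sequence on every forward pass, clipped or not. A small sketch of the F.pad semantics, with an illustrative token id:

    import torch
    import torch.nn.functional as F

    stop_text_token = 0  # illustrative id; the real value is a model attribute
    text_inputs = torch.tensor([[12, 7, 31]])

    # pad=(0, 1) appends one column on the right, filled with `value`,
    # i.e. a stop token at the end of every sequence in the batch.
    text_inputs = F.pad(text_inputs, (0, 1), value=stop_text_token)
    print(text_inputs)  # tensor([[12,  7, 31,  0]])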
@@ -413,11 +417,11 @@ class UnifiedVoice(nn.Module):
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1]  # Despite the name, these are not logits.
+                return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1]  # Despite the name, these are not logits
+                return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
 
         if return_attentions:
             return mel_logits
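The change from [:, :-1] to [:, :-2] accounts for both tokens this forward pass now adds around each sequence: the stop token appended above, plus a start token prepended when the aligned GPT inputs are built. The returned latent sequence is therefore two positions longer than the raw codes, and stripping both restores one-to-one alignment, which is exactly the invariant the injector below asserts. Illustrative length bookkeeping, assuming one start and one stop token as described:

    num_codes = 100                 # raw DVAE codes handed to forward()
    seq_len = 1 + num_codes + 1     # [start] + codes + [stop] seen by the GPT
    latents_len = seq_len - 2       # what mel_logits[:, :-2] keeps
    assert latents_len == num_codes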

View File

@@ -159,6 +159,7 @@ class GptVoiceLatentInjector(Injector):
         pretrained_path = opt['gpt_path']
         self.gpt = load_model_from_config(cfg, model_name=model_name,
                                           also_load_savepoint=False, load_path=pretrained_path).cuda().eval()
+        self.needs_move = True
         # Mel converter
         self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
         # Aux input keys.
@@ -179,10 +180,13 @@ class GptVoiceLatentInjector(Injector):
                 mel_conds.append(self.to_mel(state_cond[:, k]))
             mel_conds = torch.stack(mel_conds, dim=1)
 
-            self.dvae = self.dvae.to(mel_inputs.device)
+            if self.needs_move:
+                self.dvae = self.dvae.to(mel_inputs.device)
+                self.gpt = self.gpt.to(mel_inputs.device)
             codes = self.dvae.get_codebook_indices(mel_inputs)
-            self.gpt = self.gpt.to(codes.device)
             latents = self.gpt.forward(mel_conds, state[self.text_input_key],
                                        state[self.text_lengths_key], codes, state[self.input_lengths_key],
-                                       text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
+                                       text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
+                                       clip_inputs=False)
+            assert latents.shape[1] == codes.shape[1]
             return {self.output: latents}
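On the injector side, the device transfers for the DVAE and the GPT now happen together behind the needs_move flag instead of being re-issued on every call (the hunk shown never clears the flag, though Module.to() is already a no-op when the module is on the target device), clip_inputs=False keeps the GPT from trimming the code axis, and the new assert pins down the invariant the diffusion trainer relies on: one GPT latent per DVAE code. A condensed, hypothetical sketch of that forward logic, assuming dvae/gpt modules with the interfaces used above (the False assignment is added here so the move genuinely runs once):

    import torch
    from torch import nn

    class LatentExtractor:
        """Sketch of the one-time device move plus the alignment check."""
        def __init__(self, dvae: nn.Module, gpt: nn.Module):
            self.dvae, self.gpt = dvae, gpt
            self.needs_move = True

        @torch.no_grad()
        def extract(self, mel_inputs, mel_conds, text_inputs, text_lengths, wav_lengths):
            if self.needs_move:
                # Move both models alongside the data, once.
                self.dvae = self.dvae.to(mel_inputs.device)
                self.gpt = self.gpt.to(mel_inputs.device)
                self.needs_move = False
            codes = self.dvae.get_codebook_indices(mel_inputs)
            latents = self.gpt.forward(mel_conds, text_inputs, text_lengths, codes, wav_lengths,
                                       text_first=True, raw_mels=None, return_attentions=False,
                                       return_latent=True, clip_inputs=False)
            # One latent per code: what the diffusion trainer consumes.
            assert latents.shape[1] == codes.shape[1]
            return latents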