Fixes for training the diffusion model on autoregressive inputs

This commit is contained in:
James Betker 2022-04-11 11:02:44 -06:00
parent a3622462c1
commit 8ea5c307fb
2 changed files with 23 additions and 15 deletions

View File

@ -353,7 +353,7 @@ class UnifiedVoice(nn.Module):
enc = self.final_norm(enc)
if return_latent:
return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
first_logits = enc[:, :first_inputs.shape[1]]
first_logits = first_head(first_logits)
@ -367,7 +367,7 @@ class UnifiedVoice(nn.Module):
return first_logits
def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
return_latent=False):
return_latent=False, clip_inputs=True):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
@ -381,16 +381,20 @@ class UnifiedVoice(nn.Module):
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
"""
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
if clip_inputs:
# This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
# chopping the inputs by the maximum actual length.
max_text_len = text_lengths.max()
text_inputs = text_inputs[:, :max_text_len]
max_mel_len = wav_lengths.max() // self.mel_length_compression
mel_codes = mel_codes[:, :max_mel_len]
if raw_mels is not None:
raw_mels = raw_mels[:, :, :max_mel_len*4]
mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
conds = []
@ -413,11 +417,11 @@ class UnifiedVoice(nn.Module):
if text_first:
text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return mel_logits[:, :-1] # Despite the name, these are not logits.
return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
else:
mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
if return_latent:
return text_logits[:, :-1] # Despite the name, these are not logits
return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
if return_attentions:
return mel_logits

View File

@ -159,6 +159,7 @@ class GptVoiceLatentInjector(Injector):
pretrained_path = opt['gpt_path']
self.gpt = load_model_from_config(cfg, model_name=model_name,
also_load_savepoint=False, load_path=pretrained_path).cuda().eval()
self.needs_move = True
# Mel converter
self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
# Aux input keys.
@ -179,10 +180,13 @@ class GptVoiceLatentInjector(Injector):
mel_conds.append(self.to_mel(state_cond[:, k]))
mel_conds = torch.stack(mel_conds, dim=1)
self.dvae = self.dvae.to(mel_inputs.device)
if self.needs_move:
self.dvae = self.dvae.to(mel_inputs.device)
self.gpt = self.gpt.to(mel_inputs.device)
codes = self.dvae.get_codebook_indices(mel_inputs)
self.gpt = self.gpt.to(codes.device)
latents = self.gpt.forward(mel_conds, state[self.text_input_key],
state[self.text_lengths_key], codes, state[self.input_lengths_key],
text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
clip_inputs=False)
assert latents.shape[1] == codes.shape[1]
return {self.output: latents}