Fixes for training the diffusion model on autoregressive inputs
parent a3622462c1 · commit 8ea5c307fb
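In brief: UnifiedVoice.forward gains a clip_inputs flag (default True, preserving the old clipping behavior) and fixes the latent slice it returns, and GptVoiceLatentInjector passes clip_inputs=False so the GPT latents it hands to the diffusion trainer stay aligned one-to-one with the dVAE codes. Short illustrative sketches follow each hunk below.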
@@ -353,7 +353,7 @@ class UnifiedVoice(nn.Module):
         enc = self.final_norm(enc)
 
         if return_latent:
-            return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
 
         first_logits = enc[:, :first_inputs.shape[1]]
         first_logits = first_head(first_logits)
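This hunk fixes which slice of the transformer output is returned as the first modality's latents: the sequence fed to the transformer is the conditioning latents followed by the two input modalities, so the first-input latents start after the conditioning prefix, not at position zero. A minimal sketch of the indexing, with made-up dimensions:

    import torch

    # Toy stand-ins for the transformer output layout: the sequence is
    # [conds | first_inputs | second_inputs] concatenated on dim 1.
    b, d = 2, 8
    conds = torch.randn(b, 3, d)    # speech conditioning latents
    first = torch.randn(b, 5, d)    # e.g. text embeddings
    second = torch.randn(b, 7, d)   # e.g. mel-code embeddings
    enc = torch.cat([conds, first, second], dim=1)

    # Old slice: wrongly returns the conditioning positions.
    old = enc[:, :first.shape[1]]
    # Fixed slice: skip past the conditioning prefix first.
    new = enc[:, conds.shape[1]:conds.shape[1] + first.shape[1]]
    assert torch.equal(new, enc[:, 3:8])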
@@ -367,7 +367,7 @@ class UnifiedVoice(nn.Module):
         return first_logits
 
     def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
-                return_latent=False):
+                return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -381,16 +381,20 @@ class UnifiedVoice(nn.Module):
 
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs by the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
         speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
         conds = []
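Note the reordering: with clip_inputs=False the length-clipping block is skipped entirely, but the stop tokens still have to be appended, so the F.pad calls move out of the clipping logic and run unconditionally after set_mel_padding. A minimal sketch of what the unconditional padding does, with hypothetical token ids:

    import torch
    import torch.nn.functional as F

    stop_text_token, stop_mel_token = 0, 83  # hypothetical ids
    text_inputs = torch.tensor([[12, 7, 42]])
    mel_codes = torch.tensor([[5, 5, 9, 1]])

    # Append one stop token to every sequence, clipped or not.
    text_inputs = F.pad(text_inputs, (0, 1), value=stop_text_token)
    mel_codes = F.pad(mel_codes, (0, 1), value=stop_mel_token)
    # text_inputs -> tensor([[12,  7, 42,  0]])
    # mel_codes   -> tensor([[ 5,  5,  9,  1, 83]])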
@@ -413,11 +417,11 @@ class UnifiedVoice(nn.Module):
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1]  # Despite the name, these are not logits.
+                return mel_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1]  # Despite the name, these are not logits
+                return text_logits[:, :-2]  # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
 
         if return_attentions:
             return mel_logits
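The return_latent trim changes from one trailing position to two. Per the new comment, this forward pass grows each sequence by two tokens: the stop token appended via F.pad above, plus (an assumption from the comment, not shown in this hunk) a start token the model prepends when building aligned inputs. Dropping two positions therefore leaves exactly one latent per original mel code. Toy bookkeeping:

    # Toy length bookkeeping (assumed token handling, based on the
    # diff's comments): forward() adds two tokens around each mel
    # sequence, so [:, :-2] returns one latent per original code.
    num_codes = 100
    seq_len = num_codes + 2       # +1 start token, +1 stop token (assumed)
    latents_len = seq_len - 2     # after the [:, :-2] trim
    assert latents_len == num_codes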
@@ -159,6 +159,7 @@ class GptVoiceLatentInjector(Injector):
         pretrained_path = opt['gpt_path']
         self.gpt = load_model_from_config(cfg, model_name=model_name,
                                           also_load_savepoint=False, load_path=pretrained_path).cuda().eval()
+        self.needs_move = True
         # Mel converter
         self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
         # Aux input keys.
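The injector gains a needs_move flag so the frozen dVAE and GPT are moved to the data's device together, up front, rather than one at a time mid-step. A sketch of the pattern with a placeholder module (the flag reset is my addition; the hunk below leaves the flag set, in which case the guarded .to() calls are cheap no-ops on later steps anyway):

    import torch
    import torch.nn as nn

    # Sketch of the one-time device-move guard (placeholder module; the
    # real injector guards its dvae and gpt models this way).
    class LazyMover:
        def __init__(self):
            self.model = nn.Linear(4, 4).eval()
            self.needs_move = True

        def __call__(self, x):
            if self.needs_move:
                self.model = self.model.to(x.device)
                self.needs_move = False  # my addition; the hunk keeps the flag set
            with torch.no_grad():
                return self.model(x)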
@@ -179,10 +180,13 @@ class GptVoiceLatentInjector(Injector):
             mel_conds.append(self.to_mel(state_cond[:, k]))
         mel_conds = torch.stack(mel_conds, dim=1)
 
-        self.dvae = self.dvae.to(mel_inputs.device)
+        if self.needs_move:
+            self.dvae = self.dvae.to(mel_inputs.device)
+            self.gpt = self.gpt.to(mel_inputs.device)
         codes = self.dvae.get_codebook_indices(mel_inputs)
-        self.gpt = self.gpt.to(codes.device)
         latents = self.gpt.forward(mel_conds, state[self.text_input_key],
                                    state[self.text_lengths_key], codes, state[self.input_lengths_key],
-                                   text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
+                                   text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
+                                   clip_inputs=False)
         assert latents.shape[1] == codes.shape[1]
         return {self.output: latents}
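Finally, the injector passes clip_inputs=False: the dVAE emits one code per compressed frame of the padded mel batch, and the diffusion model needs one GPT latent per code, which is exactly what the assert enforces. Letting forward() clip mel_codes down to wav_lengths.max() would return a shorter latent sequence and trip that assert. Toy numbers:

    # Toy shape check (assumed numbers) for the clip_inputs=False call.
    codes_len = 120                  # codes.shape[1] from the padded batch
    unclipped_latents = 120 + 2 - 2  # +start/stop tokens, then the [:, :-2] trim
    clipped_latents = 100 + 2 - 2    # if forward() had clipped to max wav length
    assert unclipped_latents == codes_len  # the injector's assert holds
    assert clipped_latents != codes_len    # clip_inputs=True would fail it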