forked from mrq/DL-Art-School
Fixes for training the diffusion model on autoregressive inputs
This commit is contained in:
parent a3622462c1
commit 8ea5c307fb
@@ -353,7 +353,7 @@ class UnifiedVoice(nn.Module):
         enc = self.final_norm(enc)
 
         if return_latent:
-            return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
+            return enc[:, speech_conditioning_inputs.shape[1]:speech_conditioning_inputs.shape[1]+first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
 
         first_logits = enc[:, :first_inputs.shape[1]]
         first_logits = first_head(first_logits)
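This hunk fixes where the "first" latents are sliced from the transformer output. The sequence fed to the GPT is the conditioning embeddings followed by the first and second input modalities, so slicing from position 0 returned conditioning positions rather than the first modality's latents. A minimal sketch of the indexing, with made-up shapes (`c_len`, `f_len`, `s_len` are hypothetical):

```python
import torch

# Stand-in for the normed hidden states: [conds | first_inputs | second_inputs] on dim 1.
b, c_len, f_len, s_len, dim = 2, 3, 5, 7, 16
enc = torch.randn(b, c_len + f_len + s_len, dim)

wrong_first = enc[:, :f_len]               # old slice: starts inside the conditioning prefix
right_first = enc[:, c_len:c_len + f_len]  # new slice: skips conds, covers exactly first_inputs
second = enc[:, -s_len:]                   # trailing modality is unaffected

assert right_first.shape == (b, f_len, dim)
```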
@@ -367,7 +367,7 @@ class UnifiedVoice(nn.Module):
             return first_logits
 
     def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False,
-                return_latent=False):
+                return_latent=False, clip_inputs=True):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -381,16 +381,20 @@ class UnifiedVoice(nn.Module):
 
         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
+        If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
         """
-        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
-        # chopping the inputs by the maximum actual length.
-        max_text_len = text_lengths.max()
-        text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token)
-        max_mel_len = wav_lengths.max() // self.mel_length_compression
-        mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token)
-        if raw_mels is not None:
-            raw_mels = raw_mels[:, :, :max_mel_len*4]
+        if clip_inputs:
+            # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
+            # chopping the inputs by the maximum actual length.
+            max_text_len = text_lengths.max()
+            text_inputs = text_inputs[:, :max_text_len]
+            max_mel_len = wav_lengths.max() // self.mel_length_compression
+            mel_codes = mel_codes[:, :max_mel_len]
+            if raw_mels is not None:
+                raw_mels = raw_mels[:, :, :max_mel_len*4]
         mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
+        text_inputs = F.pad(text_inputs, (0,1), value=self.stop_text_token)
+        mel_codes = F.pad(mel_codes, (0,1), value=self.stop_mel_token)
 
         speech_conditioning_input = speech_conditioning_input.unsqueeze(1) if len(speech_conditioning_input.shape) == 3 else speech_conditioning_input
         conds = []
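These two hunks make input clipping optional. Previously the stop tokens were appended as part of the clipping step, so there was no way to skip the clipping without also losing the stop tokens. The rewrite gates only the length-trimming behind the new `clip_inputs` flag (default `True`, preserving existing behavior) and applies the stop-token padding unconditionally, after `set_mel_padding`. A condensed sketch of the control flow, with hypothetical token ids and compression factor:

```python
import torch
import torch.nn.functional as F

STOP_TEXT, STOP_MEL, MEL_COMPRESSION = 0, 8192, 1024  # hypothetical values

def prepare_inputs(text_inputs, mel_codes, text_lengths, wav_lengths, clip_inputs=True):
    if clip_inputs:
        # Optional: trim batch padding down to the longest real sequence.
        text_inputs = text_inputs[:, :text_lengths.max()]
        mel_codes = mel_codes[:, :wav_lengths.max() // MEL_COMPRESSION]
    # Unconditional: every caller still gets trailing stop tokens.
    text_inputs = F.pad(text_inputs, (0, 1), value=STOP_TEXT)
    mel_codes = F.pad(mel_codes, (0, 1), value=STOP_MEL)
    return text_inputs, mel_codes
```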
@@ -413,11 +417,11 @@ class UnifiedVoice(nn.Module):
         if text_first:
             text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return mel_logits[:, :-1] # Despite the name, these are not logits.
+                return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
         else:
             mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
             if return_latent:
-                return text_logits[:, :-1] # Despite the name, these are not logits
+                return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
 
         if return_attentions:
             return mel_logits
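With the padding above now always applied, the returned latent sequence is two positions longer than the raw codes: `forward` appends a stop token and the input-alignment step prepends a start token. Trimming two trailing positions instead of one leaves exactly one latent per input code, which the injector below asserts. A toy illustration, with hypothetical token ids and latent width:

```python
import torch
import torch.nn.functional as F

codes = torch.randint(0, 8192, (1, 10))       # raw mel codes, length 10
padded = F.pad(codes, (0, 1), value=8192)     # forward() appends a stop token   -> 11
inputs = F.pad(padded, (1, 0), value=8193)    # alignment prepends a start token -> 12

latents = torch.randn(1, inputs.shape[1], 1024)     # one hidden state per input position
assert latents[:, :-2].shape[1] == codes.shape[1]   # trim 2 -> one latent per raw code
```

The remaining hunks adapt `GptVoiceLatentInjector` to consume this.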
@@ -159,6 +159,7 @@ class GptVoiceLatentInjector(Injector):
         pretrained_path = opt['gpt_path']
         self.gpt = load_model_from_config(cfg, model_name=model_name,
                                           also_load_savepoint=False, load_path=pretrained_path).cuda().eval()
+        self.needs_move = True
         # Mel converter
         self.mel_inj = TorchMelSpectrogramInjector({'in': 'wav', 'out': 'mel', 'mel_norm_file': '../experiments/clips_mel_norms.pth'},{})
         # Aux input keys.
@@ -179,10 +180,13 @@ class GptVoiceLatentInjector(Injector):
             mel_conds.append(self.to_mel(state_cond[:, k]))
         mel_conds = torch.stack(mel_conds, dim=1)
 
-        self.dvae = self.dvae.to(mel_inputs.device)
+        if self.needs_move:
+            self.dvae = self.dvae.to(mel_inputs.device)
+            self.gpt = self.gpt.to(mel_inputs.device)
         codes = self.dvae.get_codebook_indices(mel_inputs)
-        self.gpt = self.gpt.to(codes.device)
         latents = self.gpt.forward(mel_conds, state[self.text_input_key],
                                    state[self.text_lengths_key], codes, state[self.input_lengths_key],
-                                   text_first=True, raw_mels=None, return_attentions=False, return_latent=True)
+                                   text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
+                                   clip_inputs=False)
+        assert latents.shape[1] == codes.shape[1]
         return {self.output: latents}
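The injector now moves both models to the data's device in one gated step, rather than issuing a `.to()` around the DVAE call on every batch, and requests unclipped latents so they pair 1:1 with the full, padded code sequence the diffusion model conditions on; the new assert enforces that alignment. Below is a minimal sketch of the move-once pattern behind `needs_move`; note the diff never shows the flag being cleared, so the `False` assignment here is assumed intent:

```python
import torch.nn as nn

class LazyMover:
    """Moves a group of modules to a device on first use, then skips redundant .to() calls."""
    def __init__(self, *modules: nn.Module):
        self.modules = list(modules)
        self.needs_move = True

    def ensure_on(self, device):
        if self.needs_move:
            self.modules = [m.to(device) for m in self.modules]
            self.needs_move = False  # assumed: cleared after the first move
```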