Fix training flow for NEXT TOKEN prediction instead of same token prediction
doh
This commit is contained in:
parent
d9936df363
commit
36c7c1fbdb
|
@ -60,20 +60,28 @@ class GptTts(nn.Module):
|
|||
mel_logits = self.mel_head(mel_logits)
|
||||
|
||||
# Compute loss
|
||||
loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
|
||||
loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
|
||||
text_targets = text_inputs[:,1:]
|
||||
text_logits = text_logits.permute(0,2,1)[:,:,:-1] # The last element of the logits is unneeded because the input to the transformer contains a <EOS> token for both text and mel.
|
||||
loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
|
||||
mel_targets = mel_targets[:,1:]
|
||||
mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
|
||||
loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
|
||||
|
||||
# Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
|
||||
pad_loss_reduction_factor = .01
|
||||
text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
|
||||
mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
|
||||
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
|
||||
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)
|
||||
text_pad_mask = ~get_mask_from_lengths(text_lengths-1, text_inputs.shape[1]-1) # -1 to strip off <BOS>, which is accounted for in text_lengths and output_lengths.
|
||||
mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
|
||||
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask, pad_loss_reduction_factor)
|
||||
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask, pad_loss_reduction_factor)
|
||||
|
||||
# Fix up mel_logits so it can go into a VAE decoder as well.
|
||||
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
|
||||
mel_codes = mel_codes[:,1:-1] # Strip off first and last tokens (START+STOP were added by the dataloader)
|
||||
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:-1], 0)
|
||||
mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
|
||||
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
|
||||
mel_codes = mel_codes[:,:
|
||||
|
||||
|
||||
|
||||
-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
|
||||
extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
|
||||
mel_codes = mel_codes * extra_mask
|
||||
|
||||
|
@ -85,19 +93,21 @@ class GptTts(nn.Module):
|
|||
|
||||
mel_seq = [self.MEL_START_TOKEN, 0]
|
||||
while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
|
||||
mel_emb = self.mel_embedding(LongTensor(mel_seq, device=text_inputs.device))
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_seq.shape[1], device=mel_seq.device))
|
||||
mel_seq.append(0)
|
||||
mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||
emb = torch.cat([text_emb, mel_emb], dim=1)
|
||||
enc = self.gpt(emb)
|
||||
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
|
||||
mel_logits = self.mel_head(mel_logits)
|
||||
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
|
||||
mel_seq[-1] = mel_codes[-1]
|
||||
mel_seq.append(0)
|
||||
|
||||
if len(mel_seq) >= self.max_mel_frames:
|
||||
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
|
||||
|
||||
# Prevent sending invalid tokens to the VAE
|
||||
mel_seq = [s if s < 512 else 0 for s in mel_seq]
|
||||
return mel_seq[:-1]
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ if __name__ == "__main__":
|
|||
torch.backends.cudnn.benchmark = True
|
||||
want_metrics = False
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
|
||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||
opt = option.dict_to_nonedict(opt)
|
||||
utils.util.loaded_options = opt
|
||||
|
|
Loading…
Reference in New Issue
Block a user