Fix training flow for NEXT TOKEN prediction instead of same token prediction

doh
This commit is contained in:
James Betker 2021-08-04 10:28:09 -06:00
parent d9936df363
commit 36c7c1fbdb
2 changed files with 23 additions and 13 deletions
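In rough terms, the change in the first file moves the loss from scoring each position against its own input token to scoring it against the following token. A minimal sketch of that alignment with toy tensors (shapes and vocabulary size here are illustrative, not the repository's real model):

    import torch
    import torch.nn.functional as F

    # Toy shapes only: batch of 2, sequence length 6, vocabulary of 10.
    B, T, V = 2, 6, 10
    tokens = torch.randint(0, V, (B, T))   # tokens fed to the transformer
    logits = torch.randn(B, T, V)          # one prediction per input position

    # Next-token objective: the logit at position t is scored against token t+1,
    # so the last logit column and the first target token are both dropped.
    shifted_logits = logits[:, :-1].permute(0, 2, 1)   # [B, V, T-1] for cross_entropy
    shifted_targets = tokens[:, 1:]                    # [B, T-1]
    loss = F.cross_entropy(shifted_logits, shifted_targets)

Dropping the last logit column is safe here because, as the in-line comment in the diff notes, the transformer input already ends in an <EOS> token for both text and mel.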

@@ -60,20 +60,28 @@ class GptTts(nn.Module):
mel_logits = self.mel_head(mel_logits)
# Compute loss
loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
text_targets = text_inputs[:,1:]
text_logits = text_logits.permute(0,2,1)[:,:,:-1] # The last element of the logits is unneeded because the input to the transformer contains a <EOS> token for both text and mel.
loss_text = F.cross_entropy(text_logits, text_targets, reduction='none')
mel_targets = mel_targets[:,1:]
mel_logits = mel_logits.permute(0,2,1)[:,:,:-1]
loss_mel = F.cross_entropy(mel_logits, mel_targets, reduction='none')
# Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
pad_loss_reduction_factor = .01
text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)
text_pad_mask = ~get_mask_from_lengths(text_lengths-1, text_inputs.shape[1]-1) # -1 to strip off <BOS>, which is accounted for in text_lengths and output_lengths.
mel_pad_mask = ~get_mask_from_lengths(output_lengths-1, mel_targets.shape[1])
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask, pad_loss_reduction_factor)
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask, pad_loss_reduction_factor)
# Fix up mel_logits so it can go into a VAE decoder as well.
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
mel_codes = mel_codes[:,1:-1] # Strip off first and last tokens (START+STOP were added by the dataloader)
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:-1], 0)
mel_codes = torch.argmax(F.softmax(mel_logits, dim=1), dim=1)
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask, 0)
mel_codes = mel_codes[:,:-1] # Strip off <EOS> token too (or padding). The important part is that the output sequence length is identical to the VAE input.
extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
mel_codes = mel_codes * extra_mask
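The pad handling in this hunk keeps every token in the loss but scales padded positions down by a factor of 0.01 instead of masking them out entirely. A standalone sketch of that pattern, with a simple length comparison standing in for get_mask_from_lengths (names and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    B, T, V = 2, 5, 8
    logits = torch.randn(B, V, T)                 # already permuted to [batch, vocab, time]
    targets = torch.randint(0, V, (B, T))
    lengths = torch.tensor([5, 3])                # second sequence has two padded positions

    loss = F.cross_entropy(logits, targets, reduction='none')         # [B, T], one value per token
    pad_mask = torch.arange(T).unsqueeze(0) >= lengths.unsqueeze(1)   # True at padded positions
    weights = torch.ones_like(loss).masked_fill_(pad_mask, 0.01)      # keep pads, just down-weight them
    loss = (loss * weights).mean()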
@@ -85,19 +93,21 @@ class GptTts(nn.Module):
mel_seq = [self.MEL_START_TOKEN, 0]
while mel_seq[-1] != self.MEL_STOP_TOKEN and len(mel_seq) < self.max_mel_frames:
mel_emb = self.mel_embedding(LongTensor(mel_seq, device=text_inputs.device))
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_seq.shape[1], device=mel_seq.device))
mel_seq.append(0)
mel_emb = self.mel_embedding(torch.tensor(mel_seq, dtype=torch.long, device=text_inputs.device)).unsqueeze(0)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
mel_seq[-1] = mel_codes[-1]
mel_seq.append(0)
if len(mel_seq) >= self.max_mel_frames:
print("Warning! Encountered frame limit before a stop token. Output is likely wrong.")
# Prevent sending invalid tokens to the VAE
mel_seq = [s if s < 512 else 0 for s in mel_seq]
return mel_seq[:-1]
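The inference loop above decodes greedily: append a placeholder token, run the model over the whole prefix, and overwrite the placeholder with the argmax prediction. A schematic of that placeholder-then-fill pattern, with a random dummy forward pass standing in for the text conditioning, GPT and mel head (token values and the stop check are illustrative):

    import torch
    import torch.nn.functional as F

    START_TOKEN, STOP_TOKEN, MAX_FRAMES, VOCAB = 8, 9, 20, 10

    def dummy_logits(seq):
        # Stand-in for embedding the prefix and running the GPT + mel head.
        return torch.randn(1, len(seq), VOCAB)

    mel_seq = [START_TOKEN, 0]   # <start> plus a placeholder for the first prediction
    while len(mel_seq) < MAX_FRAMES:
        logits = dummy_logits(mel_seq)                            # [1, len(mel_seq), VOCAB]
        codes = torch.argmax(F.softmax(logits, dim=-1), dim=-1)   # greedy pick at every position
        mel_seq[-1] = codes[0, -1].item()                         # fill the placeholder in place
        if mel_seq[-1] == STOP_TOKEN:
            break
        mel_seq.append(0)                                         # placeholder for the next step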

@@ -51,7 +51,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt