diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py
index 7ba5bdb7..aa30f24f 100644
--- a/codes/data/audio/grand_conjoined_dataset.py
+++ b/codes/data/audio/grand_conjoined_dataset.py
@@ -99,7 +99,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'paired_text_tokens': snt['padded_text'],
                 'paired_file': snt['filenames'],
                 'speech_audio': snt['wav'],
-                'speech_lengths': snt['wav_lengths'],
+                'speech_audio_lengths': snt['wav_lengths'],
                 'speech_file': snt['filenames'],
                 'text_text': snt['real_text'],
                 'text_tokens': snt['padded_text'],
@@ -114,7 +114,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'paired_text_tokens': snt['padded_text'],
                 'paired_file': snt['filenames'],
                 'speech_audio': sp['clip'],
-                'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
+                'speech_audio_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
                 'speech_file': sp['path'],
                 'text_text': txt,
                 'text_tokens': txt_tok,
diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py
index 54bf6ffe..562f1d80 100644
--- a/codes/models/gpt_voice/unified_voice.py
+++ b/codes/models/gpt_voice/unified_voice.py
@@ -38,7 +38,15 @@ class ConditioningEncoder(nn.Module):
         return h[:, :, 0]
 
 
-class GptTtsHf(nn.Module):
+class UnifiedGptVoice(nn.Module):
+    """
+    Derived from GptTtsHf, but offers multiple modes of operation:
+    - Text only
+    - Voice only
+    - Text conditioned on voice
+    - Voice conditioned on text
+    """
+
     NUMBER_TEXT_TOKENS = 10000  # The number of tokens produced by our bespoke BPE tokenizer.
     START_TEXT_TOKEN = 9999
     STOP_TEXT_TOKEN = 0
@@ -79,87 +87,120 @@ class GptTtsHf(nn.Module):
         tar = F.pad(input, (0,1), value=stop_token)
         return inp, tar
 
-    def get_logits(self, text_inputs, cond_input, mel_inputs, get_attns=False):
-        text_emb = self.text_embedding(text_inputs)
-        cond = self.conditioning_encoder(cond_input).unsqueeze(1)
-        mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
-
-        emb = torch.cat([text_emb, cond, mel_emb], dim=1)
-        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
-        if get_attns:
-            return gpt_out.attentions
-        enc = gpt_out.last_hidden_state
-
-        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
-        text_logits = self.text_head(text_logits)
-        text_logits = text_logits.permute(0,2,1)
-        mel_logits = self.final_norm(enc[:, -mel_emb.shape[1]:])
-        mel_logits = self.mel_head(mel_logits)
-        mel_logits = mel_logits.permute(0,2,1)
-
-        return text_logits, mel_logits
-
-    def forward(self, text_inputs, cond_input, mel_targets, wav_lengths, return_attentions=False):
+    def set_mel_padding(self, mel_input_tokens, wav_lengths):
         """
-        Forward pass
-        text_inputs: long tensor, (b,t)
-        cond_inputs: MEL float tensor, (b,c,80,s)
-        mel_targets: long tensor, (b,m)
-        mel_lengths: long tensor, (b,)
+        Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
+        that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
+        preformatting to create a working TTS model.
         """
         # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
         mel_lengths = wav_lengths // self.mel_length_compression
         for b in range(len(mel_lengths)):
-            if mel_lengths[b] < mel_targets.shape[-1]:
-                mel_targets[b, mel_lengths[b]:] = self.STOP_MEL_TOKEN
+            actual_end = mel_lengths[b] + 1  # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token.
+            if actual_end < mel_input_tokens.shape[-1]:
+                mel_input_tokens[b, actual_end:] = self.STOP_MEL_TOKEN
+        return mel_input_tokens
 
-        # Randomly permute the conditioning spectrogram, to destroy any structure present.
-        cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
+    def randomly_permute_conditioning_input(self, speech_conditioning_input):
+        """
+        Randomly permute the conditioning spectrogram, to destroy any structure present. Note that since the
+        conditioning input is derived from a discrete spectrogram, it does actually retain structure, but only a little
+        bit (actually: exactly how much we want; enough to discriminate different vocal qualities, but nothing about
+        what is being said).
+        """
+        cond_input = speech_conditioning_input[:,:,torch.randperm(speech_conditioning_input.shape[-1])]
         if cond_input.shape[-1] > self.max_conditioning_length:
             cond_input = cond_input[:,:,:self.max_conditioning_length]
+        return cond_input
+
+    def get_logits(self, speech_conditioning_input, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False):
+        if second_inputs is not None:
+            emb = torch.cat([speech_conditioning_input, first_inputs, second_inputs], dim=1)
+        else:
+            emb = torch.cat([speech_conditioning_input, first_inputs], dim=1)
+
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
+        if get_attns:
+            return gpt_out.attentions
+
+        enc = gpt_out.last_hidden_state[:, 1:]  # The first logit is tied to the speech_conditioning_input
+        first_logits = self.final_norm(enc[:, :first_inputs.shape[1]])
+        first_logits = first_head(first_logits)
+        first_logits = first_logits.permute(0,2,1)
+        if second_inputs is not None:
+            second_logits = self.final_norm(enc[:, -second_inputs.shape[1]:])
+            second_logits = second_head(second_logits)
+            second_logits = second_logits.permute(0,2,1)
+            return first_logits, second_logits
+        else:
+            return first_logits
+
+    def forward(self, speech_conditioning_input, text_inputs, mel_inputs, wav_lengths, text_first=True, return_attentions=False):
+        """
+        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
+        (actuated by `text_first`).
+
+        speech_conditioning_input: MEL float tensor, (b,80,s)
+        text_inputs: long tensor, (b,t)
+        mel_inputs: long tensor, (b,m)
+        wav_lengths: long tensor, (b,)
+        """
+        mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths)
+        speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
+        speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
 
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
-        mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_targets, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
-        text_logits, mel_logits = self.get_logits(text_inputs, cond_input, mel_inputs, get_attns=return_attentions)
+        text_emb = self.text_embedding(text_inputs)
+        mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN)
+        mel_emb = self.gpt.get_input_embeddings()(mel_inputs)
+        if text_first:
+            text_logits, mel_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions)
+        else:
+            mel_logits, text_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions)
+
         if return_attentions:
             return mel_logits
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_text.mean(), loss_mel.mean(), mel_logits
 
-    def inference(self, text_inputs, cond_input, **hf_generate_kwargs):
-        if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
+    def text_forward(self, speech_conditioning_input, text_inputs):
+        """
+        Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
+        model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
+        """
+        speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
+        speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
 
-        text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
         text_emb = self.text_embedding(text_inputs)
+        text_logits = self.get_logits(speech_conditioning_input, text_emb, self.text_head)
+        loss_text = F.cross_entropy(text_logits, text_targets.long())
+        return loss_text.mean()
 
-        # Randomly permute the conditioning spectrogram, to destroy any structure present.
-        cond_input = cond_input[:,:,torch.randperm(cond_input.shape[-1])]
-        if cond_input.shape[-1] > self.max_conditioning_length:
-            cond_input = cond_input[:,:,:self.max_conditioning_length]
-        cond = self.conditioning_encoder(cond_input).unsqueeze(1)
+    def speech_forward(self, speech_conditioning_input, mel_inputs, wav_lengths):
+        """
+        Performs autoregressive modeling on only speech data.
+ """ + mel_inputs = self.set_mel_padding(mel_inputs, wav_lengths) + speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) + speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) - emb = torch.cat([text_emb, cond], dim=1) - self.inference_model.store_mel_emb(emb) - - fake_inputs = torch.full((emb.shape[0],emb.shape[1]+1,), fill_value=1, dtype=torch.long, device=text_inputs.device) - fake_inputs[:,-1] = self.START_MEL_TOKEN - - gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN, - max_length=emb.shape[1]+self.max_mel_tokens, **hf_generate_kwargs) - return gen[:, fake_inputs.shape[1]:] + mel_inputs, mel_targets = self.build_aligned_inputs_and_targets(mel_inputs, self.START_MEL_TOKEN, self.STOP_MEL_TOKEN) + mel_emb = self.gpt.get_input_embeddings()(mel_inputs) + mel_logits = self.get_logits(speech_conditioning_input, mel_emb, self.mel_head) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_mel.mean() @register_model -def register_gpt_tts_hf(opt_net, opt): - return GptTtsHf(**opt_get(opt_net, ['kwargs'], {})) +def register_unified_gpt_voice(opt_net, opt): + return UnifiedGptVoice(**opt_get(opt_net, ['kwargs'], {})) if __name__ == '__main__': - gpt = GptTtsHf(model_dim=1024, heads=16) - l = gpt(torch.randint(high=len(symbols), size=(2,200)), - torch.arange(0, 80, 1, dtype=torch.float).view(1,80,1).repeat(2,1,800), + gpt = UnifiedGptVoice(model_dim=256, heads=4) + l = gpt(torch.randn(2, 80, 800), + torch.randint(high=len(symbols), size=(2,80)), torch.randint(high=8192, size=(2,250)), torch.tensor([150*256,195*256])) diff --git a/codes/train.py b/codes/train.py index 45c473cb..ed4384cc 100644 --- a/codes/train.py +++ b/codes/train.py @@ -286,7 +286,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_unified_voice.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()