From 525addffabbcdae5177c288dcac49d1c201e93cb Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 6 Jan 2022 10:13:45 -0700 Subject: [PATCH] Unified: automatically clip inputs according to specified max length to improve inference time --- codes/models/gpt_voice/unified_voice.py | 43 +++++++++++++++++++------ 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/codes/models/gpt_voice/unified_voice.py b/codes/models/gpt_voice/unified_voice.py index 84809bdc..8b93ef87 100644 --- a/codes/models/gpt_voice/unified_voice.py +++ b/codes/models/gpt_voice/unified_voice.py @@ -115,9 +115,9 @@ class UnifiedGptVoice(nn.Module): self.mel_length_compression = mel_length_compression self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) - self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 1, model_dim) - self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim) - seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs + self.text_pos_embedding = nn.Embedding(self.max_text_tokens + 2, model_dim) + self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 2, model_dim) + seq_length = 4+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes, n_positions=seq_length, n_ctx=seq_length, @@ -212,13 +212,14 @@ class UnifiedGptVoice(nn.Module): else: return first_logits - def forward(self, speech_conditioning_input, text_inputs, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False): + def forward(self, speech_conditioning_input, text_inputs, text_lengths, mel_codes, wav_lengths, text_first=True, raw_mels=None, return_attentions=False): """ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode (actuated by `text_first`). speech_conditioning_input: MEL float tensor, (b,80,s) text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) mel_inputs: long tensor, (b,m) wav_lengths: long tensor, (b,) raw_mels: MEL float tensor (b,80,s) @@ -226,7 +227,16 @@ class UnifiedGptVoice(nn.Module): assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len*4] mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + if self.shuffle_conditioning: speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) @@ -235,7 +245,7 @@ class UnifiedGptVoice(nn.Module): text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token) if raw_mels is not None: - mel_inp = F.pad(raw_mels, (0, 4)) + mel_inp = F.pad(raw_mels, (0, 8)) else: mel_inp = mel_codes mel_emb = self.gpt.get_input_embeddings()(mel_inp) @@ -251,13 +261,18 @@ class UnifiedGptVoice(nn.Module): loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) return loss_text.mean(), loss_mel.mean(), mel_logits - def text_forward(self, speech_conditioning_input, text_inputs): + def text_forward(self, speech_conditioning_input, text_inputs, text_lengths): """ Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). """ assert self.max_text_tokens >= text_inputs.shape[1], f'{text_inputs.shape[1]}' + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = F.pad(text_inputs[:, :max_text_len], (0,1), value=self.stop_text_token) + if self.shuffle_conditioning: speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) @@ -274,7 +289,14 @@ class UnifiedGptVoice(nn.Module): """ assert self.max_mel_tokens >= mel_codes.shape[1], f'{mel_codes.shape[1]}' + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = F.pad(mel_codes[:, :max_mel_len], (0,1), value=self.stop_mel_token) mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + if raw_mels is not None: + raw_mels = raw_mels[:, :, :max_mel_len*4] + if self.shuffle_conditioning: speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input) speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1) @@ -294,6 +316,7 @@ class UnifiedGptVoice(nn.Module): if not hasattr(self, 'inference_model'): self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) @@ -319,10 +342,10 @@ def register_unified_gpt_voice(opt_net, opt): if __name__ == '__main__': - gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=False) + gpt = UnifiedGptVoice(model_dim=256, heads=4, train_solo_embeddings=True, use_mel_codes_as_input=True) l = gpt(torch.randn(2, 80, 800), torch.randint(high=len(symbols), size=(2,80)), + torch.tensor([32, 80]), torch.randint(high=8192, size=(2,250)), - torch.tensor([150*256,195*256]), - raw_mels=torch.randn(2,80,1000)) - gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80))) + torch.tensor([150*256,195*256])) + gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))