From e7a705fe6e3b1a2919d435161f188b879fdf6cc1 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 6 Jan 2022 10:27:10 -0700
Subject: [PATCH] Make gpt_asr_hf2 more efficient at inference

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index dd21ea9c..2ecc26da 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -224,7 +224,7 @@ class GptAsrHf2(nn.Module):
     make its output useful.
     """
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000,
-                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0):
+                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0, mel_compression=256):
         super().__init__()
         self.number_text_tokens = number_text_tokens
         self.start_token = start_token
@@ -233,6 +233,7 @@ class GptAsrHf2(nn.Module):
         self.model_dim = model_dim
         self.mel_encoder = LeanMelEncoder(model_dim)
         self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction
+        self.mel_compression = mel_compression
         seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
         self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
                                      n_positions=seq_length,
@@ -293,7 +294,7 @@ class GptAsrHf2(nn.Module):
         text_logits = text_logits.permute(0,2,1)
         return text_logits
 
-    def forward(self, mel_inputs, text_inputs, return_attentions=False):
+    def forward(self, mel_inputs, wav_lengths, text_inputs, text_lengths, return_attentions=False):
         """
         "Normal" forward pass which produces a text loss when given a MEL-encoded audio clip and transcribed text targets.
         """
@@ -301,6 +302,13 @@ class GptAsrHf2(nn.Module):
         assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
         assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
 
+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
+        max_mel_len = wav_lengths.max() // self.mel_compression
+        mel_inputs = mel_inputs[:, :, :max_mel_len]
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
@@ -311,13 +319,18 @@ class GptAsrHf2(nn.Module):
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits
 
-    def text_only(self, text_inputs):
+    def text_only(self, text_inputs, text_lengths):
         """
         Used to train on only text inputs.
         """
         assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
         assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
 
+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
@@ -371,8 +384,8 @@ if __name__ == '__main__':
     #distill()
 
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,640), torch.randint(high=100, size=(2,80)))
-    gpt.text_only(torch.randint(high=100, size=(2,120)))
+    l = gpt(torch.randn(2,80,640), torch.tensor([100*256,20*256]), torch.randint(high=100, size=(2,80)), torch.tensor([15,60]))
+    gpt.text_only(torch.randint(high=100, size=(2,120)), torch.tensor([30,33]))
    #start = time()
    #gpt.inference(torch.randn(1,80,350), num_beams=1)
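
Note (not part of the patch): the trimming added in forward() and text_only() works because the data pipeline pads every sample to a global maximum at the macro-batch level, so each microbatch can safely be sliced down to its own longest sequence before the transformer runs. A minimal standalone sketch of that logic, using the shapes from the __main__ test above; the tensor names mirror the patch, and mel_compression=256 is the new __init__ default (waveform samples per MEL frame):

    import torch

    mel_compression = 256  # waveform samples per MEL frame (new __init__ default)

    # Hypothetical microbatch, padded at the macro-batch level: 2 clips,
    # 80 MEL bins padded to 640 frames; text padded to 80 tokens.
    mel_inputs = torch.randn(2, 80, 640)               # (batch, n_mels, frames)
    wav_lengths = torch.tensor([100 * 256, 20 * 256])  # true waveform sample counts
    text_inputs = torch.randint(high=100, size=(2, 80))
    text_lengths = torch.tensor([15, 60])              # true token counts

    # Slice off padding that no sample in this microbatch uses,
    # exactly as the patched forward() does.
    text_inputs = text_inputs[:, :text_lengths.max()]
    mel_inputs = mel_inputs[:, :, :wav_lengths.max() // mel_compression]

    print(text_inputs.shape)  # torch.Size([2, 60])
    print(mel_inputs.shape)   # torch.Size([2, 80, 100])

Since the attention cost of the GPT-2 backbone grows quadratically with sequence length, shrinking each microbatch to its true maximum length cuts compute without changing the loss on any real (unpadded) token.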