Make gpt_asr_hf2 more efficient at inference

James Betker 2022-01-06 10:27:10 -07:00
parent 5e1d1da2e9
commit e7a705fe6e


@@ -224,7 +224,7 @@ class GptAsrHf2(nn.Module):
make its output useful.
"""
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000,
-                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0):
+                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0, mel_compression=256):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_token = start_token
@@ -233,6 +233,7 @@ class GptAsrHf2(nn.Module):
self.model_dim = model_dim
self.mel_encoder = LeanMelEncoder(model_dim)
self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction
+        self.mel_compression = mel_compression
seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
n_positions=seq_length,
@@ -293,7 +294,7 @@ class GptAsrHf2(nn.Module):
text_logits = text_logits.permute(0,2,1)
return text_logits
-    def forward(self, mel_inputs, text_inputs, return_attentions=False):
+    def forward(self, mel_inputs, wav_lengths, text_inputs, text_lengths, return_attentions=False):
"""
"Normal" forward pass which produces a text loss when given a MEL-encoded audio clip and transcribed text
targets.
@@ -301,6 +302,13 @@ class GptAsrHf2(nn.Module):
assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
+        max_mel_len = wav_lengths.max() // self.mel_compression
+        mel_inputs = mel_inputs[:, :, :max_mel_len]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
@@ -311,13 +319,18 @@ class GptAsrHf2(nn.Module):
loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean(), text_logits
-    def text_only(self, text_inputs):
+    def text_only(self, text_inputs, text_lengths):
"""
Used to train on only text inputs.
"""
assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())
+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
@@ -371,8 +384,8 @@ if __name__ == '__main__':
#distill()
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,640), torch.randint(high=100, size=(2,80)))
-    gpt.text_only(torch.randint(high=100, size=(2,120)))
+    l = gpt(torch.randn(2,80,640), torch.tensor([100*256,20*256]), torch.randint(high=100, size=(2,80)), torch.tensor([15,60]))
+    gpt.text_only(torch.randint(high=100, size=(2,120)), torch.tensor([30,33]))
#start = time()
#gpt.inference(torch.randn(1,80,350), num_beams=1)
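
For context, a minimal usage sketch of the updated entry points follows. The shapes mirror the __main__ smoke test above; the reading of wav_lengths (unpadded waveform sample counts, converted to MEL frames via mel_compression) and text_lengths (unpadded token counts) is an assumption inferred from the trimming logic in this diff, not taken from the repo's data loader.

import torch

# Hypothetical smoke test of the new signatures (shapes copied from the diff above).
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)

mels = torch.randn(2, 80, 640)                     # padded MEL batch: (batch, n_mels, frames)
wav_lengths = torch.tensor([100 * 256, 20 * 256])  # sample counts; frames = samples // mel_compression
texts = torch.randint(high=100, size=(2, 80))      # padded text-token batch
text_lengths = torch.tensor([15, 60])              # unpadded token counts per item

# forward() now slices both inputs down to the longest real item in the batch,
# so macro-batch padding no longer inflates the attention sequence length.
loss, logits = gpt(mels, wav_lengths, texts, text_lengths)
gpt.text_only(torch.randint(high=100, size=(2, 120)), torch.tensor([30, 33]))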