Make gpt_asr_hf2 more efficient at inference
parent 5e1d1da2e9
commit e7a705fe6e
@@ -224,7 +224,7 @@ class GptAsrHf2(nn.Module):
     make its output useful.
     """
     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000,
-                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0):
+                 checkpointing=True, number_text_tokens=512, start_token=511, stop_token=0, mel_compression=256):
         super().__init__()
         self.number_text_tokens = number_text_tokens
         self.start_token = start_token
@@ -233,6 +233,7 @@ class GptAsrHf2(nn.Module):
         self.model_dim = model_dim
         self.mel_encoder = LeanMelEncoder(model_dim)
         self.max_mel_frames = max_mel_frames // self.mel_encoder.reduction
+        self.mel_compression = mel_compression
         seq_length = 2+self.max_symbols_per_phrase+self.max_mel_frames
         self.gpt_config = GPT2Config(vocab_size=self.number_text_tokens,
                                      n_positions=seq_length,
@@ -293,7 +294,7 @@ class GptAsrHf2(nn.Module):
         text_logits = text_logits.permute(0,2,1)
         return text_logits

-    def forward(self, mel_inputs, text_inputs, return_attentions=False):
+    def forward(self, mel_inputs, wav_lengths, text_inputs, text_lengths, return_attentions=False):
         """
         "Normal" forward pass which produces a text loss when given a MEL-encoded audio clip and transcribed text
         targets.
@@ -301,6 +302,13 @@ class GptAsrHf2(nn.Module):
         assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
         assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())

+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
+        max_mel_len = wav_lengths.max() // self.mel_compression
+        mel_inputs = mel_inputs[:, :, :max_mel_len]
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
@@ -311,13 +319,18 @@ class GptAsrHf2(nn.Module):
         loss_text = F.cross_entropy(text_logits, text_targets.long())
         return loss_text.mean(), text_logits

-    def text_only(self, text_inputs):
+    def text_only(self, text_inputs, text_lengths):
         """
         Used to train on only text inputs.
         """
         assert text_inputs.shape[1] <= self.max_symbols_per_phrase, str(text_inputs.shape[1])
         assert text_inputs.max() <= self.number_text_tokens, str(text_inputs.max())

+        # Trim off excessive inputs to speed training. This might seem odd, but consider that this model is fed microbatches
+        # which are padded at the macro-batch level.
+        max_text_len = text_lengths.max()
+        text_inputs = text_inputs[:, :max_text_len]
+
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_token, self.stop_token)
         text_emb = self.gpt.get_input_embeddings()(text_inputs) + \
                    self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + \
@@ -371,8 +384,8 @@ if __name__ == '__main__':
     #distill()

     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,640), torch.randint(high=100, size=(2,80)))
-    gpt.text_only(torch.randint(high=100, size=(2,120)))
+    l = gpt(torch.randn(2,80,640), torch.tensor([100*256,20*256]), torch.randint(high=100, size=(2,80)), torch.tensor([15,60]))
+    gpt.text_only(torch.randint(high=100, size=(2,120)), torch.tensor([30,33]))

     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
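The trimming this commit adds exploits the fact that microbatches are padded to the longest sample in the whole macro-batch, so most batches carry padding columns that no sample in them actually uses. Below is a minimal standalone sketch of the idea; the helper name trim_padded_batch is hypothetical, and mel_compression is assumed to be the MEL hop length, i.e. the number of waveform samples per MEL frame (256 here, matching the new default).

import torch

def trim_padded_batch(mel_inputs, wav_lengths, text_inputs, text_lengths, mel_compression=256):
    # Hypothetical helper sketching the trim logic the commit inlines into forward().
    # Drop text columns beyond the longest real transcript in this batch.
    max_text_len = text_lengths.max()
    text_inputs = text_inputs[:, :max_text_len]
    # Convert the longest waveform length to MEL frames and trim the spectrogram likewise.
    max_mel_len = wav_lengths.max() // mel_compression
    mel_inputs = mel_inputs[:, :, :max_mel_len]
    return mel_inputs, text_inputs

# Worked example mirroring the test stub above: waveforms of 100*256 and 20*256
# samples span at most 100 MEL frames, so a (2, 80, 640) MEL tensor shrinks to
# (2, 80, 100), and an 80-column text batch with lengths (15, 60) shrinks to 60 columns.
mel, text = trim_padded_batch(torch.randn(2, 80, 640),
                              torch.tensor([100*256, 20*256]),
                              torch.randint(high=100, size=(2, 80)),
                              torch.tensor([15, 60]))
assert mel.shape == (2, 80, 100) and text.shape == (2, 60)

Since self-attention cost grows quadratically with sequence length, cutting unused padding out of both the text and MEL axes shortens the GPT-2 input sequence and should speed up every forward pass without changing what the model computes on the real tokens.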