gpt_asr_hf2: remove dual positional embeddings

James Betker 2021-12-28 10:57:45 -07:00
parent 93624fa4b2
commit 312f631c5b


@@ -243,7 +243,7 @@ class GptAsrHf2(nn.Module):
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         # Initialize the embeddings per the GPT-2 scheme
-        for module in [self.text_pos_embedding, self.mel_pos_embedding]:
+        for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
             module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
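
The hunk above follows the GPT-2 initialization scheme: each learned positional-embedding table is filled from a normal distribution whose standard deviation is the model's initializer_range, and the padding row is zeroed if the embedding defines one. A minimal standalone sketch of that scheme; the table sizes, model width, and the 0.02 range below are illustrative assumptions rather than values taken from this file:

    from torch import nn

    initializer_range = 0.02   # assumed GPT-2 default; the model reads it from self.gpt.config
    model_dim = 512
    # Hypothetical table sizes; the real ones derive from max_symbols_per_phrase / max_mel_frames.
    text_pos_embedding = nn.Embedding(252, model_dim)
    text_solo_pos_embedding = nn.Embedding(252, model_dim)
    mel_pos_embedding = nn.Embedding(350, model_dim)

    for module in [text_pos_embedding, text_solo_pos_embedding, mel_pos_embedding]:
        # GPT-2 scheme: small normal init, then zero the padding row when one exists.
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
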
@@ -334,8 +334,8 @@ if __name__ == '__main__':
     #distill()
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
-    gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
+    l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
+    gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
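
The smoke test only resizes the random inputs; the calling convention is unchanged: the first argument is a batch of mel spectrograms shaped (batch, 80 mel channels, frames) and the second is a batch of integer text tokens shaped (batch, sequence length), both inside the configured maxima (max_mel_frames=1400, max_symbols_per_phrase=250). A small sketch of the assumed shapes, with a hypothetical vocabulary size standing in for len(symbols):

    import torch

    batch_size = 2
    n_mel_channels = 80    # mel bins the model expects
    n_mel_frames = 640     # stays below max_mel_frames=1400
    vocab_size = 148       # hypothetical stand-in for len(symbols)
    text_len = 80          # stays below max_symbols_per_phrase=250

    mel = torch.randn(batch_size, n_mel_channels, n_mel_frames)
    text_tokens = torch.randint(high=vocab_size, size=(batch_size, text_len))
    # mel and text_tokens mirror the two arguments passed to gpt(...) above;
    # gpt.text_only(...) takes only the token tensor.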