gpt_asr_hf2: remove dual positional embeddings

Author: James Betker
Date: 2021-12-28 10:57:45 -07:00
Parent: 93624fa4b2
Commit: 312f631c5b


@@ -243,7 +243,7 @@ class GptAsrHf2(nn.Module):
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         # Initialize the embeddings per the GPT-2 scheme
-        for module in [self.text_pos_embedding, self.mel_pos_embedding]:
+        for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
             module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
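For context, the initialization loop above is the standard GPT-2 scheme: each embedding table is sampled from a normal distribution whose standard deviation is the config's initializer_range, and the padding row, if any, is zeroed. A minimal standalone sketch of the same loop follows; the table shapes and the 0.02 value are illustrative assumptions, not taken from this commit:

```python
import torch.nn as nn

# Hypothetical stand-ins for the model's positional tables; shapes are illustrative.
text_pos_embedding = nn.Embedding(250, 512)
text_solo_pos_embedding = nn.Embedding(250, 512)
mel_pos_embedding = nn.Embedding(1400, 512)

initializer_range = 0.02  # assumption: GPT-2's default; the real value comes from self.gpt.config

for module in [text_pos_embedding, text_solo_pos_embedding, mel_pos_embedding]:
    module.weight.data.normal_(mean=0.0, std=initializer_range)
    if module.padding_idx is not None:  # None for these tables, kept to mirror the commit
        module.weight.data[module.padding_idx].zero_()
```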
@@ -334,8 +334,8 @@ if __name__ == '__main__':
     #distill()
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
-    gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
+    l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
+    gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
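The second hunk only resizes the tensors fed to the __main__ smoke test: the joint mel+text forward now uses a 640-frame spectrogram with 80 text tokens, and the text-only path (which now has its own text_solo_pos_embedding) is exercised with a longer 120-token batch. A hedged reconstruction of those inputs; 148 is an assumed stand-in for len(symbols), which the real script imports:

```python
import torch

num_symbols = 148  # assumption: placeholder for len(symbols) from the project's symbols module

mel = torch.randn(2, 80, 640)                         # (batch, mel_bins=80, frames), within max_mel_frames=1400
text = torch.randint(high=num_symbols, size=(2, 80))  # (batch, tokens), within max_symbols_per_phrase=250
text_solo = torch.randint(high=num_symbols, size=(2, 120))

# With the model importable, the smoke test then runs:
# gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
# loss = gpt(mel, text)      # joint mel+text path
# gpt.text_only(text_solo)   # text-only path
```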