gpt_asr_hf2: remove dual positional embeddings
parent 93624fa4b2
commit 312f631c5b
@@ -243,7 +243,7 @@ class GptAsrHf2(nn.Module):
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
 
         # Initialize the embeddings per the GPT-2 scheme
-        for module in [self.text_pos_embedding, self.mel_pos_embedding]:
+        for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
             module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
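The loop above is the standard GPT-2 initialization for embedding tables: weights drawn from a normal distribution with std = initializer_range, and the padding row (if any) zeroed. A minimal standalone sketch of the same scheme, using illustrative sizes rather than the model's real dimensions:

import torch
import torch.nn as nn

# Illustrative sizes only; the real tables are sized from
# max_symbols_per_phrase / max_mel_frames and model_dim.
initializer_range = 0.02  # GPT-2's default
text_pos_embedding = nn.Embedding(250, 512)
text_solo_pos_embedding = nn.Embedding(250, 512)
mel_pos_embedding = nn.Embedding(1400, 512)

for module in [text_pos_embedding, text_solo_pos_embedding, mel_pos_embedding]:
    module.weight.data.normal_(mean=0.0, std=initializer_range)
    if module.padding_idx is not None:  # None here, but zeroed when one is set
        module.weight.data[module.padding_idx].zero_()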
@@ -334,8 +334,8 @@ if __name__ == '__main__':
     #distill()
 
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
-    gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
+    l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
+    gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
 
     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
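For reference, the smoke test feeds the model a batch of mel spectrograms shaped (batch, n_mel_channels, n_frames) plus integer text tokens shaped (batch, n_tokens); text_only() exercises the text-only path, presumably the one the new text_solo_pos_embedding serves. A hedged sketch of the input construction, where vocab_size is a stand-in for len(symbols) from the repo:

import torch

# Stand-in values; len(symbols) and the model itself come from the repo.
vocab_size = 148
mel = torch.randn(2, 80, 640)                        # (batch, mel bins, frames)
text = torch.randint(high=vocab_size, size=(2, 80))  # (batch, text tokens)
# loss = gpt(mel, text)   # joint mel+text forward pass
# gpt.text_only(text)     # text-only path with its own positional table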