gpt_asr_hf2: remove dual positional embeddings
This commit is contained in:
parent
93624fa4b2
commit
312f631c5b
|
@ -243,7 +243,7 @@ class GptAsrHf2(nn.Module):
|
||||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||||
|
|
||||||
# Initialize the embeddings per the GPT-2 scheme
|
# Initialize the embeddings per the GPT-2 scheme
|
||||||
for module in [self.text_pos_embedding, self.mel_pos_embedding]:
|
for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
|
||||||
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
|
||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
@ -334,8 +334,8 @@ if __name__ == '__main__':
|
||||||
#distill()
|
#distill()
|
||||||
|
|
||||||
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
|
gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
|
||||||
l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
|
l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
|
||||||
gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
|
gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
|
||||||
|
|
||||||
#start = time()
|
#start = time()
|
||||||
#gpt.inference(torch.randn(1,80,350), num_beams=1)
|
#gpt.inference(torch.randn(1,80,350), num_beams=1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user