From 312f631c5b1f6c89965aaa6abbf2f48848ee4b14 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 28 Dec 2021 10:57:45 -0700
Subject: [PATCH] gpt_asr_hf2: remove dual positional embeddings

---
 codes/models/gpt_voice/gpt_asr_hf2.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_asr_hf2.py b/codes/models/gpt_voice/gpt_asr_hf2.py
index 6b4a8e2b..46c1d83d 100644
--- a/codes/models/gpt_voice/gpt_asr_hf2.py
+++ b/codes/models/gpt_voice/gpt_asr_hf2.py
@@ -243,7 +243,7 @@ class GptAsrHf2(nn.Module):
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
 
         # Initialize the embeddings per the GPT-2 scheme
-        for module in [self.text_pos_embedding, self.mel_pos_embedding]:
+        for module in [self.text_pos_embedding, self.text_solo_pos_embedding, self.mel_pos_embedding]:
             module.weight.data.normal_(mean=0.0, std=self.gpt.config.initializer_range)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
@@ -334,8 +334,8 @@ if __name__ == '__main__':
     #distill()
 
     gpt = GptAsrHf2(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
-    l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
-    gpt.text_only(torch.randint(high=len(symbols), size=(2,100)))
+    l = gpt(torch.randn(2,80,640), torch.randint(high=len(symbols), size=(2,80)))
+    gpt.text_only(torch.randint(high=len(symbols), size=(2,120)))
     #start = time()
     #gpt.inference(torch.randn(1,80,350), num_beams=1)
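
For reference, below is a self-contained sketch of the GPT-2-style embedding
initialization that the first hunk extends to cover text_solo_pos_embedding.
The embedding sizes are hypothetical placeholders, and initializer_range is
hard-coded to GPT-2's default of 0.02 rather than read from
self.gpt.config.initializer_range as in the real module.

import torch.nn as nn

model_dim = 512
initializer_range = 0.02  # assumed; gpt_asr_hf2 reads this from the HF GPT-2 config

# Placeholder positional embeddings standing in for the three in GptAsrHf2;
# the sizes here are illustrative, not taken from the source file.
text_pos_embedding = nn.Embedding(252, model_dim)
text_solo_pos_embedding = nn.Embedding(252, model_dim)
mel_pos_embedding = nn.Embedding(350, model_dim)

for module in [text_pos_embedding, text_solo_pos_embedding, mel_pos_embedding]:
    # Sample weights from N(0, initializer_range^2), then zero the padding
    # row if one exists, mirroring HuggingFace GPT-2's _init_weights.
    module.weight.data.normal_(mean=0.0, std=initializer_range)
    if module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()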