From c584320cf30da8ad10e51227aba5698684275dcf Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 7 Nov 2021 21:53:21 -0700 Subject: [PATCH] Fix gpt_asr_hf distillation --- codes/models/gpt_voice/gpt_asr_hf.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py index 4439a1d4..67e509f4 100644 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ b/codes/models/gpt_voice/gpt_asr_hf.py @@ -231,7 +231,7 @@ class GptAsrHf(nn.Module): self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - def get_logits(self, mel_inputs, text_targets, get_attns): + def get_logits(self, mel_inputs, text_targets, get_attns=False): # Pad front and back. Pad at front is the "START" token. text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) @@ -290,25 +290,29 @@ def register_gpt_asr_hf(opt_net, opt): # Quick script that loads a model and halves the number of layers, then saves that model. def distill(): - gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12) - gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth')) + gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8) + gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth')) rc = 0 i = 0 - while i < len(gpt.gpt.layers.layers): + while i < len(gpt.gpt.h): if rc % 2 != 0: - del gpt.gpt.layers.layers[i] + del gpt.gpt.h[i] else: i += 1 rc += 1 - torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth') + torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth') if __name__ == '__main__': + distill() + + ''' gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8) #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100))) start = time() gpt.inference(torch.randn(1,80,350), num_beams=1) print(f"Elapsed: {time()-start}") + ''' ''' with torch.no_grad():