Fix gpt_asr_hf distillation

This commit is contained in:
James Betker 2021-11-07 21:53:21 -07:00
parent 9b3c3b1227
commit c584320cf3

View File

@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def get_logits(self, mel_inputs, text_targets, get_attns):
def get_logits(self, mel_inputs, text_targets, get_attns=False):
# Pad front and back. Pad at front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@ -290,25 +290,29 @@ def register_gpt_asr_hf(opt_net, opt):
# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
rc = 0
i = 0
while i < len(gpt.gpt.layers.layers):
while i < len(gpt.gpt.h):
if rc % 2 != 0:
del gpt.gpt.layers.layers[i]
del gpt.gpt.h[i]
else:
i += 1
rc += 1
torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')
if __name__ == '__main__':
distill()
'''
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
start = time()
gpt.inference(torch.randn(1,80,350), num_beams=1)
print(f"Elapsed: {time()-start}")
'''
'''
with torch.no_grad():