Merge remote-tracking branch 'origin/master'

This commit is contained in:
James Betker 2021-11-08 20:10:49 -07:00
commit 5d5558893a

View File

@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def get_logits(self, mel_inputs, text_targets, get_attns): def get_logits(self, mel_inputs, text_targets, get_attns=False):
# Pad front and back. Pad at front is the "START" token. # Pad front and back. Pad at front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@ -290,25 +290,29 @@ def register_gpt_asr_hf(opt_net, opt):
# Quick script that loads a model and halves the number of layers, then saves that model. # Quick script that loads a model and halves the number of layers, then saves that model.
def distill(): def distill():
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12) gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth')) gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
rc = 0 rc = 0
i = 0 i = 0
while i < len(gpt.gpt.layers.layers): while i < len(gpt.gpt.h):
if rc % 2 != 0: if rc % 2 != 0:
del gpt.gpt.layers.layers[i] del gpt.gpt.h[i]
else: else:
i += 1 i += 1
rc += 1 rc += 1
torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth') torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')
if __name__ == '__main__': if __name__ == '__main__':
distill()
'''
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8) gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100))) #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
start = time() start = time()
gpt.inference(torch.randn(1,80,350), num_beams=1) gpt.inference(torch.randn(1,80,350), num_beams=1)
print(f"Elapsed: {time()-start}") print(f"Elapsed: {time()-start}")
'''
''' '''
with torch.no_grad(): with torch.no_grad():