Fix gpt_asr_hf distillation
This commit is contained in:
parent
9b3c3b1227
commit
c584320cf3
|
@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
|
|||
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
|
||||
|
||||
|
||||
def get_logits(self, mel_inputs, text_targets, get_attns):
|
||||
def get_logits(self, mel_inputs, text_targets, get_attns=False):
|
||||
# Pad front and back. Pad at front is the "START" token.
|
||||
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
|
||||
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
|
||||
|
@ -290,25 +290,29 @@ def register_gpt_asr_hf(opt_net, opt):
|
|||
|
||||
# Quick script that loads a model and halves the number of layers, then saves that model.
|
||||
def distill():
|
||||
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
|
||||
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
|
||||
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
|
||||
gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
|
||||
rc = 0
|
||||
i = 0
|
||||
while i < len(gpt.gpt.layers.layers):
|
||||
while i < len(gpt.gpt.h):
|
||||
if rc % 2 != 0:
|
||||
del gpt.gpt.layers.layers[i]
|
||||
del gpt.gpt.h[i]
|
||||
else:
|
||||
i += 1
|
||||
rc += 1
|
||||
torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
|
||||
torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
distill()
|
||||
|
||||
'''
|
||||
gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
|
||||
#l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
|
||||
start = time()
|
||||
gpt.inference(torch.randn(1,80,350), num_beams=1)
|
||||
print(f"Elapsed: {time()-start}")
|
||||
'''
|
||||
|
||||
'''
|
||||
with torch.no_grad():
|
||||
|
|
Loading…
Reference in New Issue
Block a user