Merge remote-tracking branch 'origin/master'
commit 5d5558893a
@@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)


-    def get_logits(self, mel_inputs, text_targets, get_attns):
+    def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front and back. Pad at front is the "START" token.
         text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
         text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
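The only change in this hunk is giving get_attns a default of False, so call sites that only want logits need not pass the flag. Below is a minimal sketch of that calling pattern, assuming the HF GPT-2 backbone surfaces attention maps via output_attentions; the TinyAsrHead class, its dimensions, and its token count are made up for illustration and are not the repo's actual class:

import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2Model

class TinyAsrHead(nn.Module):
    def __init__(self, model_dim=64, n_tokens=148):
        super().__init__()
        # Tiny GPT-2 stand-in; sizes are arbitrary for the example.
        self.gpt = GPT2Model(GPT2Config(n_embd=model_dim, n_layer=2, n_head=2))
        self.text_head = nn.Linear(model_dim, n_tokens)

    def get_logits(self, embeddings, get_attns=False):
        # Ask HF for per-layer attention maps only when the caller requests them.
        out = self.gpt(inputs_embeds=embeddings, output_attentions=get_attns, return_dict=True)
        if get_attns:
            return out.attentions  # tuple of (batch, heads, seq, seq) tensors
        return self.text_head(out.last_hidden_state)

logits = TinyAsrHead().get_logits(torch.randn(2, 10, 64))  # old call sites keep working

With the default in place, existing two-argument callers are untouched, which is presumably the point of the change.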
@@ -290,25 +290,29 @@ def register_gpt_asr_hf(opt_net, opt):

# Quick script that loads a model and halves the number of layers, then saves that model.
def distill():
-    gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
-    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
+    gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=512, heads=8)
+    gpt.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_ema.pth'))
     rc = 0
     i = 0
-    while i < len(gpt.gpt.layers.layers):
+    while i < len(gpt.gpt.h):
         if rc % 2 != 0:
-            del gpt.gpt.layers.layers[i]
+            del gpt.gpt.h[i]
         else:
             i += 1
         rc += 1
-    torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    torch.save(gpt.state_dict(), 'X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\48000_gpt_distilled.pth')


 if __name__ == '__main__':
+    distill()
+
+    '''
     gpt = GptAsrHf(max_symbols_per_phrase=250, max_mel_frames=1400, layers=16, model_dim=512, heads=8)
     #l = gpt(torch.randn(2,80,800), torch.randint(high=len(symbols), size=(2,100)))
     start = time()
     gpt.inference(torch.randn(1,80,350), num_beams=1)
     print(f"Elapsed: {time()-start}")
+    '''

 '''
     with torch.no_grad():
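The loop in distill() drops every other transformer block: rc counts visits, and i advances only when a block is kept, so the surviving blocks are the even-indexed ones. The rename from gpt.gpt.layers.layers to gpt.gpt.h matches the move to the HuggingFace GPT2Model, whose block list is the h ModuleList. Here is a standalone sketch of the same trick applied to a plain GPT2Model; the config sizes and save path are invented for the example:

import torch
from transformers import GPT2Config, GPT2Model

model = GPT2Model(GPT2Config(n_embd=64, n_layer=12, n_head=2))

# Keep blocks 0, 2, 4, ... Deleting from a ModuleList while iterating is
# easy to get wrong, which is why the script above advances i only when a
# block is kept; rebuilding the list expresses the same result directly.
model.h = torch.nn.ModuleList(
    block for idx, block in enumerate(model.h) if idx % 2 == 0)
model.config.n_layer = len(model.h)  # keep the config consistent

assert len(model.h) == 6
# torch.save(model.state_dict(), 'distilled.pth')  # as in the script above

Either formulation leaves a half-depth network whose alternating blocks were never trained to compose, so the saved checkpoint is presumably a starting point for further distillation training, as the function's name suggests.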