Check in GPT with new inference methods (but not the backing code..)

This commit is contained in:
James Betker 2021-10-29 17:21:40 -06:00
parent 0822792d79
commit 986fc9628d
2 changed files with 77 additions and 18 deletions

View File

@ -1,3 +1,5 @@
from time import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -92,12 +94,12 @@ class GptAsr(nn.Module):
loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long()) loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
return loss_text.mean(), text_logits return loss_text.mean(), text_logits
def inference_beam_topk(self, mel): def inference_beam_topk(self, mel, fn='inference_beam'):
def topk_sampler(distribution, k): def topk_sampler(distribution, k):
return torch.topk(distribution, k=k, dim=-1) return torch.topk(distribution, k=k, dim=-1)
return self.inference_beam(mel, topk_sampler) return getattr(self, fn)(mel, topk_sampler)
def inference_beam_sampled(self, mel): def inference_beam_sampled(self, mel, fn='inference_beam'):
def multinomial_sampler(distribution, k): def multinomial_sampler(distribution, k):
indices = torch.multinomial(distribution, num_samples=k, replacement=False) indices = torch.multinomial(distribution, num_samples=k, replacement=False)
values = torch.gather(distribution, dim=1, index=indices) values = torch.gather(distribution, dim=1, index=indices)
@ -106,7 +108,7 @@ class GptAsr(nn.Module):
self.indices = i self.indices = i
self.values = v self.values = v
return container(indices, values) return container(indices, values)
return self.inference_beam(mel, multinomial_sampler) return getattr(self, fn)(mel, multinomial_sampler)
def inference_beam(self, mel_inputs, sampler_fn): def inference_beam(self, mel_inputs, sampler_fn):
beam_width = 16 beam_width = 16
@ -151,33 +153,90 @@ class GptAsr(nn.Module):
return text_seq return text_seq
def inference_beam_opt(self, mel_inputs, sampler_fn):
beam_width = 16
temperature = .8
b, _, s = mel_inputs.shape
assert b == 1 # Beam search only works on batches of one.
mel_emb = self.mel_encoder(mel_inputs)
mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
intermediates = []
text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
probabilities = torch.ones((b,), device=mel_emb.device)
while text_seq.shape[-1] < self.max_symbols_per_phrase:
text_emb = self.text_embedding(text_seq)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
if text_emb.shape[0] != mel_emb.shape[0]:
mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
emb = torch.cat([mel_emb, text_emb], dim=1)
if len(intermediates) == 0:
enc, intermediates = self.gpt(emb, return_intermediates=True)
intermediates = [(i[0].repeat(beam_width, 1, 1),
i[1].repeat(beam_width, 1, 1)) for i in intermediates]
else:
enc, intermediates = self.gpt.infer_last_two(emb, intermediates)
text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
text_logits = self.text_head(text_logits)
topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
probabilities, sort_indices = torch.sort(probabilities, descending=True)
probabilities = probabilities[:beam_width]
text_seq = text_seq.repeat_interleave(beam_width, dim=0)
codes = topk.indices.flatten()
text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
text_seq = text_seq[sort_indices]
text_seq = text_seq[:beam_width]
# PAD doubles as a stop token. PAD=0.
if torch.all(torch.any(text_seq == 0, dim=1)):
break
if text_seq.shape[1] >= self.max_mel_frames:
print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
return text_seq
@register_model @register_model
def register_gpt_asr(opt_net, opt): def register_gpt_asr(opt_net, opt):
return GptAsr(**opt_get(opt_net, ['kwargs'], {})) return GptAsr(**opt_get(opt_net, ['kwargs'], {}))
# Halves the number of layers in the provided model. # Quick script that loads a model and halves the number of layers, then saves that model.
def distill(model): def distill():
gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
rc = 0 rc = 0
i = 0 i = 0
while i < len(model.gpt.layers.layers): while i < len(gpt.gpt.layers.layers):
if rc % 2 != 0: if rc % 2 != 0:
del model.gpt.layers.layers[i] del gpt.gpt.layers.layers[i]
else: else:
i += 1 i += 1
rc += 1 rc += 1
return model torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
if __name__ == '__main__': if __name__ == '__main__':
gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12) gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
student = distill(gpt)
torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
#l = gpt(torch.randn(2,80,800), #l = gpt(torch.randn(2,80,800),
# torch.randint(high=len(symbols), size=(2,180))) # torch.randint(high=len(symbols), size=(2,180)))
#o = gpt.infer(torch.randint(high=24, size=(2,60))) with torch.no_grad():
#print(o.shape) t = torch.randn(1,80,800).cuda()
start = time()
s = gpt.inference_beam_topk(t)
print(time()-start)
start = time()
o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
print(time()-start)

View File

@ -41,7 +41,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
want_metrics = False want_metrics = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt utils.util.loaded_options = opt
@ -71,8 +71,8 @@ if __name__ == "__main__":
tq = tqdm(test_loader) tq = tqdm(test_loader)
for data in tq: for data in tq:
#if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255: if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
# continue continue
pred = forward_pass(model, data, dataset_dir, opt, batch) pred = forward_pass(model, data, dataset_dir, opt, batch)
pred = pred.replace('_', '') pred = pred.replace('_', '')
output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n') output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')