Check in GPT with new inference methods (but not the backing code..)
parent 0822792d79 · commit 986fc9628d
@@ -1,3 +1,5 @@
+from time import time
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -92,12 +94,12 @@ class GptAsr(nn.Module):
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
         return loss_text.mean(), text_logits

-    def inference_beam_topk(self, mel):
+    def inference_beam_topk(self, mel, fn='inference_beam'):
         def topk_sampler(distribution, k):
             return torch.topk(distribution, k=k, dim=-1)
-        return self.inference_beam(mel, topk_sampler)
+        return getattr(self, fn)(mel, topk_sampler)

-    def inference_beam_sampled(self, mel):
+    def inference_beam_sampled(self, mel, fn='inference_beam'):
         def multinomial_sampler(distribution, k):
             indices = torch.multinomial(distribution, num_samples=k, replacement=False)
             values = torch.gather(distribution, dim=1, index=indices)
@@ -106,7 +108,7 @@ class GptAsr(nn.Module):
                     self.indices = i
                     self.values = v
             return container(indices, values)
-        return self.inference_beam(mel, multinomial_sampler)
+        return getattr(self, fn)(mel, multinomial_sampler)

     def inference_beam(self, mel_inputs, sampler_fn):
         beam_width = 16
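Aside: both wrappers now take a fn keyword and route through getattr(self, fn), so the same sampler closure can drive either the original inference_beam or the new inference_beam_opt below. The container class exists only so the multinomial path returns the same .values/.indices interface that torch.topk produces. A minimal standalone sketch of that contract (the tensors here are made up):

    import torch
    import torch.nn.functional as F

    # Both samplers must expose .values and .indices of shape (batch, k),
    # because the beam loop reads topk.values.flatten() and topk.indices.flatten().
    dist = F.softmax(torch.randn(1, 148), dim=-1)   # stand-in symbol distribution

    topk = torch.topk(dist, k=4, dim=-1)            # named tuple: .values, .indices

    idx = torch.multinomial(dist, num_samples=4, replacement=False)
    val = torch.gather(dist, dim=1, index=idx)      # same shapes as the topk result
    assert val.shape == topk.values.shape and idx.shape == topk.indices.shape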
@@ -151,33 +153,90 @@ class GptAsr(nn.Module):
         return text_seq


+    def inference_beam_opt(self, mel_inputs, sampler_fn):
+        beam_width = 16
+        temperature = .8
+
+        b, _, s = mel_inputs.shape
+        assert b == 1  # Beam search only works on batches of one.
+        mel_emb = self.mel_encoder(mel_inputs)
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
+        mel_emb = mel_emb.permute(0,2,1).contiguous()
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+
+        intermediates = []
+        text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
+        probabilities = torch.ones((b,), device=mel_emb.device)
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
+            text_emb = self.text_embedding(text_seq)
+            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
+            if text_emb.shape[0] != mel_emb.shape[0]:
+                mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
+            emb = torch.cat([mel_emb, text_emb], dim=1)
+
+            if len(intermediates) == 0:
+                enc, intermediates = self.gpt(emb, return_intermediates=True)
+                intermediates = [(i[0].repeat(beam_width, 1, 1),
+                                  i[1].repeat(beam_width, 1, 1)) for i in intermediates]
+            else:
+                enc, intermediates = self.gpt.infer_last_two(emb, intermediates)
+
+            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
+            text_logits = self.text_head(text_logits)
+            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
+            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
+            probabilities, sort_indices = torch.sort(probabilities, descending=True)
+            probabilities = probabilities[:beam_width]
+
+            text_seq = text_seq.repeat_interleave(beam_width, dim=0)
+            codes = topk.indices.flatten()
+            text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
+            text_seq = text_seq[sort_indices]
+            text_seq = text_seq[:beam_width]
+
+            # PAD doubles as a stop token. PAD=0.
+            if torch.all(torch.any(text_seq == 0, dim=1)):
+                break
+
+            if text_seq.shape[1] >= self.max_mel_frames:
+                print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
+
+        return text_seq
+
+
 @register_model
 def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))


-# Halves the number of layers in the provided model.
-def distill(model):
+# Quick script that loads a model and halves the number of layers, then saves that model.
+def distill():
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
     rc = 0
     i = 0
-    while i < len(model.gpt.layers.layers):
+    while i < len(gpt.gpt.layers.layers):
         if rc % 2 != 0:
-            del model.gpt.layers.layers[i]
+            del gpt.gpt.layers.layers[i]
         else:
             i += 1
         rc += 1
-    return model
+    torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')


 if __name__ == '__main__':
-    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
-    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
-    student = distill(gpt)
-    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
     #l = gpt(torch.randn(2,80,800),
     #        torch.randint(high=len(symbols), size=(2,180)))

     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
+
+    with torch.no_grad():
+        t = torch.randn(1,80,800).cuda()
+        start = time()
+        s = gpt.inference_beam_topk(t)
+        print(time()-start)
+
+        start = time()
+        o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
+        print(time()-start)
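Note on the caching interface: per the commit message, the backing code is not checked in, so return_intermediates=True and gpt.infer_last_two() are defined nowhere in this diff. The pattern they imply is key/value caching: the first call runs the full prefix and returns per-layer (key, value) tensors, and later calls feed only the newest tokens while reusing the cache. A rough standalone sketch of the idea, with entirely hypothetical names (CachedSelfAttention is not the repo's API):

    import torch
    import torch.nn as nn

    class CachedSelfAttention(nn.Module):
        # Toy single-head attention that appends to a (key, value) cache.
        # Causal masking is omitted for brevity.
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Linear(dim, dim * 3)

        def forward(self, x, cache=None):
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            if cache is not None:
                k = torch.cat([cache[0], k], dim=1)   # reuse keys from earlier steps
                v = torch.cat([cache[1], v], dim=1)
            attn = torch.softmax(q @ k.transpose(1, 2) / k.shape[-1] ** .5, dim=-1)
            return attn @ v, (k, v)

    layer = CachedSelfAttention(16)
    prefix = torch.randn(1, 10, 16)
    out, cache = layer(prefix)               # full pass: like return_intermediates=True
    new_tokens = torch.randn(1, 1, 16)
    out, cache = layer(new_tokens, cache)    # incremental pass: like infer_last_two

This also explains the repeat(beam_width, 1, 1) on the intermediates: the cache is computed once at batch size 1, then tiled so every beam hypothesis can reuse it.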
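How the beam bookkeeping works: at each step, every one of the beam_width running hypotheses proposes beam_width candidate symbols, joint scores are formed as products of probabilities, and a single sort collapses the beam_width**2 extensions (just beam_width on the first step, when text_seq still has batch size 1) back down to the best beam_width. Two side notes: the softmax is taken over temperature * text_logits, which at 0.8 flattens the distribution; the more common spelling is logits / temperature. And the frame-limit warning compares the text length against max_mel_frames, which looks like it was meant to be max_symbols_per_phrase, the bound the loop itself uses. The bookkeeping in miniature, with made-up numbers:

    import torch

    beam_width = 3
    probabilities = torch.tensor([.5, .3, .2])    # running hypothesis scores
    cand = torch.rand(beam_width, beam_width)     # sampler output: (beams, k)

    # Every (hypothesis, candidate) pair gets a joint score...
    joint = probabilities.repeat_interleave(beam_width) * cand.flatten()
    # ...then one sort keeps the best beam_width extensions.
    joint, sort_indices = torch.sort(joint, descending=True)
    parents = sort_indices[:beam_width] // beam_width  # parent hypothesis of each survivor
    print(joint[:beam_width], parents)

text_seq gets the same repeat_interleave/sort/truncate treatment so sequences stay aligned with their scores.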
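The rewritten distill() is now a self-contained script rather than a function over an arbitrary model: it loads a fixed 12-layer checkpoint, deletes every odd-indexed transformer layer in place, and saves the result. Strictly speaking this is depth pruning rather than distillation in the teacher/student sense, since nothing here retrains against the original's outputs. The index dance works on any nn.ModuleList; a sketch with a stand-in list (GptAsr is not constructed here):

    import torch.nn as nn

    layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(12))

    rc, i = 0, 0
    while i < len(layers):
        if rc % 2 != 0:
            del layers[i]    # deletion shifts later layers down, so i stays put
        else:
            i += 1
        rc += 1

    assert len(layers) == 6  # the even-indexed originals survive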
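One caveat on the timing block in __main__: CUDA kernels launch asynchronously, so time() around a .cuda() model call can under-report unless the device is synchronized. This is a general PyTorch measurement caveat, not something the diff addresses; a small helper sketch:

    import torch
    from time import time

    def timed(fn, *args):
        torch.cuda.synchronize()    # drain prior work before starting the clock
        start = time()
        out = fn(*args)
        torch.cuda.synchronize()    # wait for fn's kernels to actually finish
        return out, time() - start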
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -71,8 +71,8 @@ if __name__ == "__main__":

     tq = tqdm(test_loader)
     for data in tq:
-        #if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
-        #    continue
+        if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
+            continue
         pred = forward_pass(model, data, dataset_dir, opt, batch)
         pred = pred.replace('_', '')
         output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')
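The new gate skips clips whose raw sample count exceeds max_mel_frames * 255, i.e. clips that would produce more mel frames than the model's positional embeddings cover. The 255 is presumably the STFT hop length in samples; that is an inference from the code, not stated in the diff. Back-of-envelope, under that assumption:

    # Assumed numbers: 1400 mel frames (the config above) at a 255-sample hop.
    max_mel_frames = 1400
    hop_length = 255                         # inferred from the *255 in the diff
    max_samples = max_mel_frames * hop_length
    print(max_samples, max_samples / 22050)  # ~357k samples, ~16.2 s at 22.05 kHz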