Check in GPT with new inference methods (but not the backing code..)
This commit is contained in:
parent 0822792d79
commit 986fc9628d
@@ -1,3 +1,5 @@
+from time import time
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -92,12 +94,12 @@ class GptAsr(nn.Module):
         loss_text = F.cross_entropy(text_logits[:,:,:-1], text_targets[:,1:].long())
         return loss_text.mean(), text_logits

-    def inference_beam_topk(self, mel):
+    def inference_beam_topk(self, mel, fn='inference_beam'):
         def topk_sampler(distribution, k):
             return torch.topk(distribution, k=k, dim=-1)
-        return self.inference_beam(mel, topk_sampler)
+        return getattr(self, fn)(mel, topk_sampler)

-    def inference_beam_sampled(self, mel):
+    def inference_beam_sampled(self, mel, fn='inference_beam'):
         def multinomial_sampler(distribution, k):
             indices = torch.multinomial(distribution, num_samples=k, replacement=False)
             values = torch.gather(distribution, dim=1, index=indices)
@@ -106,7 +108,7 @@ class GptAsr(nn.Module):
                     self.indices = i
                     self.values = v
             return container(indices, values)
-        return self.inference_beam(mel, multinomial_sampler)
+        return getattr(self, fn)(mel, multinomial_sampler)

     def inference_beam(self, mel_inputs, sampler_fn):
         beam_width = 16
@@ -151,33 +153,90 @@ class GptAsr(nn.Module):
         return text_seq

+    def inference_beam_opt(self, mel_inputs, sampler_fn):
+        beam_width = 16
+        temperature = .8
+
+        b, _, s = mel_inputs.shape
+        assert b == 1  # Beam search only works on batches of one.
+        mel_emb = self.mel_encoder(mel_inputs)
+        mel_emb = F.pad(mel_emb, (0, self.max_mel_frames - mel_emb.shape[-1]))
+        mel_emb = mel_emb.permute(0,2,1).contiguous()
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
+
+        intermediates = []
+        text_seq = torch.full((b,1), fill_value=self.NUMBER_SYMBOLS, device=mel_emb.device)
+        probabilities = torch.ones((b,), device=mel_emb.device)
+        while text_seq.shape[-1] < self.max_symbols_per_phrase:
+            text_emb = self.text_embedding(text_seq)
+            text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=mel_emb.device))
+            if text_emb.shape[0] != mel_emb.shape[0]:
+                mel_emb = mel_emb.repeat(text_emb.shape[0], 1, 1)
+            emb = torch.cat([mel_emb, text_emb], dim=1)
+
+            if len(intermediates) == 0:
+                enc, intermediates = self.gpt(emb, return_intermediates=True)
+                intermediates = [(i[0].repeat(beam_width, 1, 1),
+                                  i[1].repeat(beam_width, 1, 1)) for i in intermediates]
+            else:
+                enc, intermediates = self.gpt.infer_last_two(emb, intermediates)
+
+            text_logits = self.final_norm(enc[:, mel_emb.shape[1]:])
+            text_logits = self.text_head(text_logits)
+            topk = sampler_fn(F.softmax(temperature * text_logits[:, -1], dim=-1), k=beam_width)
+            probabilities = (probabilities.repeat_interleave(beam_width, dim=0) * topk.values.flatten())
+            probabilities, sort_indices = torch.sort(probabilities, descending=True)
+            probabilities = probabilities[:beam_width]
+
+            text_seq = text_seq.repeat_interleave(beam_width, dim=0)
+            codes = topk.indices.flatten()
+            text_seq = torch.cat([text_seq, codes.unsqueeze(1)], dim=1)
+            text_seq = text_seq[sort_indices]
+            text_seq = text_seq[:beam_width]
+
+            # PAD doubles as a stop token. PAD=0.
+            if torch.all(torch.any(text_seq == 0, dim=1)):
+                break
+
+        if text_seq.shape[1] >= self.max_mel_frames:
+            print("Warning! Encountered frame limit before a pad token. Output is likely wrong.")
+
+        return text_seq
+

 @register_model
 def register_gpt_asr(opt_net, opt):
     return GptAsr(**opt_get(opt_net, ['kwargs'], {}))


-# Halves the number of layers in the provided model.
-def distill(model):
+# Quick script that loads a model and halves the number of layers, then saves that model.
+def distill():
+    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
+    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
     rc = 0
     i = 0
-    while i < len(model.gpt.layers.layers):
+    while i < len(gpt.gpt.layers.layers):
         if rc % 2 != 0:
-            del model.gpt.layers.layers[i]
+            del gpt.gpt.layers.layers[i]
         else:
             i += 1
         rc += 1
-    return model
+    torch.save(gpt.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')


 if __name__ == '__main__':
-    gpt = GptAsr(max_symbols_per_phrase=250, max_mel_frames=1400, layers=12, model_dim=768, heads=12)
-    gpt.load_state_dict(torch.load('../experiments/train_gpt_asr_mass/models/21500_mel_gen.pth'))
-    student = distill(gpt)
-    torch.save(student.state_dict(), '../experiments/train_gpt_asr_mass/models/21500_mel_gen_distilled.pth')
+    gpt = GptAsr(max_symbols_per_phrase=100, max_mel_frames=200, layers=6, model_dim=256, heads=2).cuda()
     #l = gpt(torch.randn(2,80,800),
     #        torch.randint(high=len(symbols), size=(2,180)))

-    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
-    #print(o.shape)
+    with torch.no_grad():
+        t = torch.randn(1,80,800).cuda()
+        start = time()
+        s = gpt.inference_beam_topk(t)
+        print(time()-start)
+
+        start = time()
+        o = gpt.inference_beam_topk(t, fn='inference_beam_opt')
+        print(time()-start)

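The optimized decode path above leans on two transformer entry points that, per the commit message, are not checked in here: self.gpt(emb, return_intermediates=True) and self.gpt.infer_last_two(emb, intermediates). The idea is standard incremental decoding: cache per-layer activations on the first full pass so each later beam step only recomputes the newly appended positions instead of re-running attention over the whole mel-plus-text sequence. Below is a minimal sketch of what such an interface could look like; the two method names come straight from the diff, while the layer signature, cache layout, and return shapes are assumptions, not the author's implementation.

# Hypothetical sketch only -- NOT part of this commit. Method names are taken from
# the calls in inference_beam_opt above; everything else is assumed.
import torch.nn as nn

class CachedGptSketch(nn.Module):
    def __init__(self, layers):
        super().__init__()
        # Each layer is assumed to accept (x, cached_kv) and return (x, kv), where
        # kv is a (keys, values) pair that later steps can attend against.
        self.layers = nn.ModuleList(layers)

    def forward(self, x, return_intermediates=False):
        # Full-sequence pass; optionally hand back one (keys, values) pair per layer
        # so the beam search can replicate the cache across beams.
        intermediates = []
        for layer in self.layers:
            x, kv = layer(x, cached_kv=None)
            intermediates.append(kv)
        return (x, intermediates) if return_intermediates else x

    def infer_last_two(self, x, intermediates):
        # Incremental pass: treat only the last two positions of x as new and serve
        # everything earlier from the cached keys/values. What the real method returns
        # for earlier positions is not visible in this diff (the caller slices from
        # mel_emb.shape[1] but only reads the final position), so this sketch simply
        # returns the recomputed tail plus the refreshed cache.
        new_x = x[:, -2:]
        updated = []
        for layer, kv in zip(self.layers, intermediates):
            new_x, kv = layer(new_x, cached_kv=kv)
            updated.append(kv)
        return new_x, updated

The payoff is that each step costs attention over only the new positions rather than the full sequence, which is what the back-to-back timings of inference_beam_topk(t) and inference_beam_topk(t, fn='inference_beam_opt') in the __main__ block are meant to compare.
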
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_deepspeech_libri.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_asr_mass.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -71,8 +71,8 @@ if __name__ == "__main__":

     tq = tqdm(test_loader)
     for data in tq:
-        #if data['clips'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
-        #    continue
+        if data['clip'].shape[-1] > opt['networks']['asr_gen']['kwargs']['max_mel_frames']*255:
+            continue
         pred = forward_pass(model, data, dataset_dir, opt, batch)
         pred = pred.replace('_', '')
         output.write(f'{pred}\t{os.path.basename(data["path"][0])}\n')