Add script for computing attention for gpt_asr

This commit is contained in:
James Betker 2021-11-07 18:42:06 -07:00
parent 3c0f2fbb21
commit a367ea3fda
4 changed files with 51 additions and 5 deletions

1
.gitignore vendored
View File

@ -21,6 +21,7 @@ data/*
*.pt *.pt
*.pth *.pth
*.pdf *.pdf
*.tsv
# template # template

View File

@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
def get_logits(self, mel_inputs, text_targets): def get_logits(self, mel_inputs, text_targets, get_attns):
# Pad front and back. Pad at front is the "START" token. # Pad front and back. Pad at front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
@ -242,14 +242,19 @@ class GptAsrHf(nn.Module):
mel_emb = mel_emb.permute(0,2,1).contiguous() mel_emb = mel_emb.permute(0,2,1).contiguous()
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
emb = torch.cat([mel_emb, text_emb], dim=1) emb = torch.cat([mel_emb, text_emb], dim=1)
enc = self.gpt(inputs_embeds=emb, return_dict=True).last_hidden_state gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state
text_logits = self.final_norm(enc[:, self.max_mel_frames:]) text_logits = self.final_norm(enc[:, self.max_mel_frames:])
text_logits = self.text_head(text_logits) text_logits = self.text_head(text_logits)
text_logits = text_logits.permute(0,2,1) text_logits = text_logits.permute(0,2,1)
return text_logits return text_logits
def forward(self, mel_inputs, text_targets): def forward(self, mel_inputs, text_targets, return_attentions=False):
text_logits = self.get_logits(mel_inputs, text_targets) text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
if return_attentions:
return text_logits # These weren't really the logits.
loss_text = F.cross_entropy(text_logits, text_targets.long()) loss_text = F.cross_entropy(text_logits, text_targets.long())
return loss_text.mean(), text_logits return loss_text.mean(), text_logits

View File

@ -0,0 +1,40 @@
import os
import numpy
import torch
import torch.nn as nn
from matplotlib import pyplot
from torch.utils.tensorboard import SummaryWriter
from data.audio.unsupervised_audio_dataset import load_audio
from models.gpt_voice.gpt_asr_hf import GptAsrHf
from models.tacotron2.text import text_to_sequence
from trainer.injectors.base_injectors import MelSpectrogramInjector
if __name__ == '__main__':
audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0)
audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1]))
mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
mel = mel_inj({'in': audio_data})['out'].cuda()
actual_text = 'and it doesn\'t take very long.'
labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda()
model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8)
model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth'))
model = model.cuda()
with torch.no_grad():
attentions = model(mel, labels, return_attentions=True)
attentions = torch.stack(attentions, dim=0).permute(0,1,2,4,3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames]
attentions = attentions.sum(0).sum(1).squeeze()
xs = [str(i) for i in range(1, model.max_mel_frames+1, 1)]
os.makedirs('results', exist_ok=True)
logger = SummaryWriter('results')
for e, character_attn in enumerate(attentions):
if e >= len(actual_text):
break
fig = pyplot.figure()
ax = fig.add_axes([0,0,1,1])
ax.bar(xs, character_attn.cpu().numpy())
logger.add_figure(f'{e}_{actual_text[e]}', fig)

View File

@ -57,7 +57,7 @@ class WordErrorRate:
if __name__ == '__main__': if __name__ == '__main__':
inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv' inference_tsv = 'D:\\dlas\\codes\\46000ema_8beam.tsv'
libri_base = 'Z:\\libritts\\test-clean' libri_base = 'Z:\\libritts\\test-clean'
wer = WordErrorRate() wer = WordErrorRate()