diff --git a/.gitignore b/.gitignore index be2f72bd..baea0e89 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ data/* *.pt *.pth *.pdf +*.tsv # template diff --git a/codes/models/gpt_voice/gpt_asr_hf.py b/codes/models/gpt_voice/gpt_asr_hf.py index 9a486e7b..4439a1d4 100644 --- a/codes/models/gpt_voice/gpt_asr_hf.py +++ b/codes/models/gpt_voice/gpt_asr_hf.py @@ -231,7 +231,7 @@ class GptAsrHf(nn.Module): self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) - def get_logits(self, mel_inputs, text_targets): + def get_logits(self, mel_inputs, text_targets, get_attns): # Pad front and back. Pad at front is the "START" token. text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS) text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1])) @@ -242,14 +242,19 @@ class GptAsrHf(nn.Module): mel_emb = mel_emb.permute(0,2,1).contiguous() mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device)) emb = torch.cat([mel_emb, text_emb], dim=1) - enc = self.gpt(inputs_embeds=emb, return_dict=True).last_hidden_state + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + enc = gpt_out.last_hidden_state text_logits = self.final_norm(enc[:, self.max_mel_frames:]) text_logits = self.text_head(text_logits) text_logits = text_logits.permute(0,2,1) return text_logits - def forward(self, mel_inputs, text_targets): - text_logits = self.get_logits(mel_inputs, text_targets) + def forward(self, mel_inputs, text_targets, return_attentions=False): + text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions) + if return_attentions: + return text_logits # These weren't really the logits. loss_text = F.cross_entropy(text_logits, text_targets.long()) return loss_text.mean(), text_logits diff --git a/codes/scripts/audio/compute_gpt_attention.py b/codes/scripts/audio/compute_gpt_attention.py new file mode 100644 index 00000000..520a6998 --- /dev/null +++ b/codes/scripts/audio/compute_gpt_attention.py @@ -0,0 +1,40 @@ +import os + +import numpy +import torch +import torch.nn as nn +from matplotlib import pyplot +from torch.utils.tensorboard import SummaryWriter + +from data.audio.unsupervised_audio_dataset import load_audio +from models.gpt_voice.gpt_asr_hf import GptAsrHf +from models.tacotron2.text import text_to_sequence +from trainer.injectors.base_injectors import MelSpectrogramInjector + +if __name__ == '__main__': + audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0) + audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1])) + mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {}) + mel = mel_inj({'in': audio_data})['out'].cuda() + actual_text = 'and it doesn\'t take very long.' + labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda() + + model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8) + model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth')) + model = model.cuda() + + with torch.no_grad(): + attentions = model(mel, labels, return_attentions=True) + attentions = torch.stack(attentions, dim=0).permute(0,1,2,4,3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames] + attentions = attentions.sum(0).sum(1).squeeze() + + xs = [str(i) for i in range(1, model.max_mel_frames+1, 1)] + os.makedirs('results', exist_ok=True) + logger = SummaryWriter('results') + for e, character_attn in enumerate(attentions): + if e >= len(actual_text): + break + fig = pyplot.figure() + ax = fig.add_axes([0,0,1,1]) + ax.bar(xs, character_attn.cpu().numpy()) + logger.add_figure(f'{e}_{actual_text[e]}', fig) diff --git a/codes/scripts/audio/word_error_rate.py b/codes/scripts/audio/word_error_rate.py index 5f663230..c67032b6 100644 --- a/codes/scripts/audio/word_error_rate.py +++ b/codes/scripts/audio/word_error_rate.py @@ -57,7 +57,7 @@ class WordErrorRate: if __name__ == '__main__': - inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv' + inference_tsv = 'D:\\dlas\\codes\\46000ema_8beam.tsv' libri_base = 'Z:\\libritts\\test-clean' wer = WordErrorRate()