forked from mrq/DL-Art-School
Add script for computing attention for gpt_asr
This commit is contained in:
parent
3c0f2fbb21
commit
a367ea3fda
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -21,6 +21,7 @@ data/*
|
|||
*.pt
|
||||
*.pth
|
||||
*.pdf
|
||||
*.tsv
|
||||
|
||||
# template
|
||||
|
||||
|
|
|
@ -231,7 +231,7 @@ class GptAsrHf(nn.Module):
|
|||
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
|
||||
|
||||
|
||||
def get_logits(self, mel_inputs, text_targets):
|
||||
def get_logits(self, mel_inputs, text_targets, get_attns):
|
||||
# Pad front and back. Pad at front is the "START" token.
|
||||
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)
|
||||
text_targets = F.pad(text_targets, (0, self.max_symbols_per_phrase - text_targets.shape[1]))
|
||||
|
@ -242,14 +242,19 @@ class GptAsrHf(nn.Module):
|
|||
mel_emb = mel_emb.permute(0,2,1).contiguous()
|
||||
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_emb.shape[1], device=mel_emb.device))
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
enc = self.gpt(inputs_embeds=emb, return_dict=True).last_hidden_state
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
enc = gpt_out.last_hidden_state
|
||||
text_logits = self.final_norm(enc[:, self.max_mel_frames:])
|
||||
text_logits = self.text_head(text_logits)
|
||||
text_logits = text_logits.permute(0,2,1)
|
||||
return text_logits
|
||||
|
||||
def forward(self, mel_inputs, text_targets):
|
||||
text_logits = self.get_logits(mel_inputs, text_targets)
|
||||
def forward(self, mel_inputs, text_targets, return_attentions=False):
|
||||
text_logits = self.get_logits(mel_inputs, text_targets, get_attns=return_attentions)
|
||||
if return_attentions:
|
||||
return text_logits # These weren't really the logits.
|
||||
loss_text = F.cross_entropy(text_logits, text_targets.long())
|
||||
return loss_text.mean(), text_logits
|
||||
|
||||
|
|
40
codes/scripts/audio/compute_gpt_attention.py
Normal file
40
codes/scripts/audio/compute_gpt_attention.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import os
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from matplotlib import pyplot
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from data.audio.unsupervised_audio_dataset import load_audio
|
||||
from models.gpt_voice.gpt_asr_hf import GptAsrHf
|
||||
from models.tacotron2.text import text_to_sequence
|
||||
from trainer.injectors.base_injectors import MelSpectrogramInjector
|
||||
|
||||
if __name__ == '__main__':
|
||||
audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0)
|
||||
audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1]))
|
||||
mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
|
||||
mel = mel_inj({'in': audio_data})['out'].cuda()
|
||||
actual_text = 'and it doesn\'t take very long.'
|
||||
labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda()
|
||||
|
||||
model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8)
|
||||
model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth'))
|
||||
model = model.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
attentions = model(mel, labels, return_attentions=True)
|
||||
attentions = torch.stack(attentions, dim=0).permute(0,1,2,4,3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames]
|
||||
attentions = attentions.sum(0).sum(1).squeeze()
|
||||
|
||||
xs = [str(i) for i in range(1, model.max_mel_frames+1, 1)]
|
||||
os.makedirs('results', exist_ok=True)
|
||||
logger = SummaryWriter('results')
|
||||
for e, character_attn in enumerate(attentions):
|
||||
if e >= len(actual_text):
|
||||
break
|
||||
fig = pyplot.figure()
|
||||
ax = fig.add_axes([0,0,1,1])
|
||||
ax.bar(xs, character_attn.cpu().numpy())
|
||||
logger.add_figure(f'{e}_{actual_text[e]}', fig)
|
|
@ -57,7 +57,7 @@ class WordErrorRate:
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
inference_tsv = 'D:\\dlas\\codes\\31000ema_8_beam.tsv'
|
||||
inference_tsv = 'D:\\dlas\\codes\\46000ema_8beam.tsv'
|
||||
libri_base = 'Z:\\libritts\\test-clean'
|
||||
|
||||
wer = WordErrorRate()
|
||||
|
|
Loading…
Reference in New Issue
Block a user