# forked from mrq/DL-Art-School
import os

import numpy
import torch
import torch.nn as nn
from matplotlib import pyplot
from torch.utils.tensorboard import SummaryWriter

from data.audio.unsupervised_audio_dataset import load_audio
from models.gpt_voice.gpt_asr_hf import GptAsrHf
from models.tacotron2.text import text_to_sequence
from trainer.injectors.base_injectors import MelSpectrogramInjector
|
if __name__ == '__main__':
    # Visualize where the GPT-ASR model attends across the input mel
    # spectrogram for each character of a known transcript, logging one
    # bar chart per character to tensorboard (under ./results).

    # Load one fixed sample clip at 22050Hz and right-pad it to the length
    # the rest of the pipeline expects.
    audio_data = load_audio('Z:\\split\\classified\\fine\\books1\\2_dchha03 The Organization of Peace\\00010.wav', 22050).unsqueeze(0)
    audio_data = torch.nn.functional.pad(audio_data, (0, 358395-audio_data.shape[-1]))

    # Convert waveform -> mel spectrogram via the trainer's injector.
    mel_inj = MelSpectrogramInjector({'in': 'in', 'out': 'out'}, {})
    mel = mel_inj({'in': audio_data})['out'].cuda()

    # Ground-truth transcript of the clip, encoded to symbol IDs.
    actual_text = 'and it doesn\'t take very long.'
    labels = torch.IntTensor(text_to_sequence(actual_text, ['english_cleaners'])).unsqueeze(0).cuda()

    model = GptAsrHf(layers=12, model_dim=512, max_mel_frames=1400, max_symbols_per_phrase=250, heads=8)
    model.load_state_dict(torch.load('X:\\dlas\\experiments\\train_gpt_asr_mass_hf\\models\\31000_gpt_ema.pth'))
    model = model.cuda()
    # FIX: switch to inference mode so any dropout in the model is disabled
    # and the attention maps are deterministic.
    model.eval()

    with torch.no_grad():
        attentions = model(mel, labels, return_attentions=True)
        # Stack the per-layer attention maps, swap the last two dims so rows
        # index text symbols, then crop to the text rows (last
        # max_symbols_per_phrase positions) attending over the mel columns
        # (first max_mel_frames positions) of the combined sequence.
        attentions = torch.stack(attentions, dim=0).permute(0, 1, 2, 4, 3)[:, :, :, -model.max_symbols_per_phrase:, :model.max_mel_frames]
        # Aggregate over layers (dim 0) and heads -> (symbols, mel_frames).
        attentions = attentions.sum(0).sum(1).squeeze()

    # One bar per mel frame; labels are just the 1-based frame index.
    xs = [str(i) for i in range(1, model.max_mel_frames + 1)]
    os.makedirs('results', exist_ok=True)
    logger = SummaryWriter('results')
    for e, character_attn in enumerate(attentions):
        # The attention tensor has max_symbols_per_phrase rows; only the
        # first len(actual_text) correspond to real characters.
        if e >= len(actual_text):
            break
        fig = pyplot.figure()
        ax = fig.add_axes([0, 0, 1, 1])
        ax.bar(xs, character_attn.cpu().numpy())
        # add_figure closes the figure for us (close=True by default).
        logger.add_figure(f'{e}_{actual_text[e]}', fig)
    # FIX: close the writer so all pending tensorboard events are flushed
    # before the script exits.
    logger.close()