Finish up mods for next version of GptAsrHf

This commit is contained in:
James Betker 2021-11-20 21:33:49 -07:00
parent 14f3155ec4
commit 0604060580
3 changed files with 71 additions and 12 deletions

View File

@ -108,9 +108,9 @@ class TextWavLoader(torch.utils.data.Dataset):
return {
'real_text': text,
'padded_text': tseq,
'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
'wav': wav,
'output_lengths': torch.tensor(orig_output, dtype=torch.long),
'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
'filenames': path
}
return tseq, wav, path, text
@ -159,9 +159,9 @@ class TextMelCollate():
return {
'padded_text': text_padded,
'input_lengths': input_lengths,
'text_lengths': input_lengths,
'wav': wav_padded,
'output_lengths': output_lengths,
'wav_lengths': output_lengths,
'filenames': filenames,
'real_text': real_text,
}
@ -171,14 +171,14 @@ if __name__ == '__main__':
batch_sz = 32
params = {
'mode': 'nv_tacotron',
'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv',
'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
'phase': 'train',
'n_workers': 0,
'n_workers': 1,
'batch_size': batch_sz,
'fetcher_mode': 'mozilla_cv',
'fetcher_mode': ['libritts'],
'needs_collate': True,
#'max_wav_length': 256000,
#'max_text_length': 200,
'max_wav_length': 256000,
'max_text_length': 200,
'sample_rate': 22050,
}
from data import create_dataset, create_dataloader

View File

@ -202,7 +202,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
class GptAsrHf2(nn.Module):
NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
START_TOKEN = NUMBER_SYMBOLS
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+2
def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True):
super().__init__()
@ -230,7 +231,7 @@ class GptAsrHf2(nn.Module):
def get_logits(self, mel_inputs, text_targets, get_attns=False):
# Pad the front and drop the last element to set up next-token prediction. The pad at the front is the "START" token.
text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)[:, :-1]
text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
text_emb = self.gpt.get_input_embeddings()(text_targets)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
mel_emb = self.mel_encoder(mel_inputs)

View File

@ -3,6 +3,8 @@
import numpy as np
import random
import torch
import torchvision.utils
from trainer.inject import Injector
@ -34,6 +36,7 @@ def spec_augment(mel_spectrogram, frequency_masking_para=27, time_masking_para=7
return mel_spectrogram
class MelMaskInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
@ -54,7 +57,8 @@ def visualization_spectrogram(spec, title):
spec = ((spec + 1) / 2).clip(0, 1)
torchvision.utils.save_image(spec, f'{title}.png')
if __name__ == '__main__':
def test_mel_mask():
from data.audio.unsupervised_audio_dataset import load_audio
from trainer.injectors.base_injectors import MelSpectrogramInjector
spec_maker = MelSpectrogramInjector({'in': 'audio', 'out': 'spec'}, {})
@ -63,3 +67,57 @@ if __name__ == '__main__':
visualization_spectrogram(s, 'original spec')
saug = spec_augment(s, 50, 5, 1, 3)
visualization_spectrogram(saug, 'modified spec')
'''
Crafty bespoke injector that is used when training ASR models to create longer sequences to ensure that the entire
input length embedding is trained. Does this by concatenating every other batch element together to create longer
sequences which (theoretically) use similar amounts of GPU memory.
'''
class CombineMelInjector(Injector):
    """Merge adjacent batch elements into longer audio/text sequences.

    Batch elements ``2*i`` and ``2*i + 1`` are each trimmed to their true length, concatenated, and zero-padded to
    a common width, which halves the batch size. The two texts are joined with a separator token between them.

    Options read from ``opt``:
        audio_key: state key of the padded audio batch; the combined audio is returned under the same key.
        text_key: state key of the padded text batch; the combined text is returned under the same key.
        audio_lengths_key: state key holding the unpadded length of every audio element.
        text_lengths_key: state key holding the unpadded length of every text element.
    """
    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.audio_key = opt['audio_key']
        self.text_key = opt['text_key']
        self.audio_lengths = opt['audio_lengths_key']
        self.text_lengths = opt['text_lengths_key']
        # Imported lazily so the symbol table is only loaded when this injector is actually constructed.
        from models.tacotron2.text import symbols
        # NOTE(review): presumably len(symbols) itself is reserved for the model's START token - confirm against
        # the consuming model.
        self.text_separator = len(symbols)+1  # Probably need to allow this to be set by user.

    def forward(self, state):
        """Combine each consecutive pair of batch elements into a single element.

        Assumes the audio tensor is (batch, samples) and the text tensor is (batch, tokens) - TODO confirm for
        inputs with additional dimensions, since lengths are applied along dim 1.

        Returns a dict with the combined audio, of width ``2 * audio.shape[-1]``, under ``audio_key`` and the
        combined text, of width ``2 * texts.shape[-1] + 1``, under ``text_key``. The extra text column makes room
        for the separator token: without it, two full-width texts would overflow the target width by one and the
        negative pad amount would silently crop the final token of the second text.
        """
        audio = state[self.audio_key]
        texts = state[self.text_key]
        audio_lengths = state[self.audio_lengths]
        text_lengths = state[self.text_lengths]
        # Elements are consumed in pairs, so an odd batch would leave one element behind.
        assert audio.shape[0] % 2 == 0, 'CombineMelInjector requires an even number of batch elements.'

        pad = torch.nn.functional.pad
        combined_audio_width = audio.shape[-1] * 2
        # Two texts plus the one separator token inserted between them.
        combined_text_width = texts.shape[-1] * 2 + 1

        combined_audios = []
        combined_texts = []
        for b in range(audio.shape[0]//2):
            first, second = b*2, b*2+1

            a = torch.cat([audio[first, :audio_lengths[first]],
                           audio[second, :audio_lengths[second]]], dim=0)
            combined_audios.append(pad(a, (0, combined_audio_width - a.shape[-1])))

            # Append the separator to the first text before joining it with the second.
            t1 = pad(texts[first, :text_lengths[first]], (0, 1), value=self.text_separator)
            t2 = texts[second, :text_lengths[second]]
            t = torch.cat([t1, t2], dim=0)
            combined_texts.append(pad(t, (0, combined_text_width - t.shape[-1])))

        return {self.audio_key: torch.stack(combined_audios, dim=0),
                self.text_key: torch.stack(combined_texts, dim=0)}
def test_mel_injector():
    """Smoke-test CombineMelInjector by running it once on a random batch of four audio/text elements."""
    injector_opts = {'audio_key': 'a', 'text_key': 't', 'audio_lengths_key': "alk", 'text_lengths_key': 'tlk'}
    injector = CombineMelInjector(injector_opts, {})

    # Random audio and token batches together with the unpadded length of every element.
    audio_batch = torch.rand((4, 22000))
    audio_lens = torch.tensor([11000, 14000, 22000, 20000])
    text_batch = torch.randint(0, 120, (4, 250))
    text_lens = torch.tensor([100, 120, 200, 250])

    # The result is not inspected; this only checks that the injector runs without raising.
    injector({'a': audio_batch, 't': text_batch, 'alk': audio_lens, 'tlk': text_lens})
# Script entry point: run the injector smoke test when this module is executed directly.
if __name__ == "__main__":
    test_mel_injector()