forked from mrq/DL-Art-School

commit 0604060580 (parent 14f3155ec4)

    Finish up mods for next version of GptAsrHf
@@ -108,9 +108,9 @@ class TextWavLoader(torch.utils.data.Dataset):
         return {
             'real_text': text,
             'padded_text': tseq,
-            'input_lengths': torch.tensor(orig_text_len, dtype=torch.long),
+            'text_lengths': torch.tensor(orig_text_len, dtype=torch.long),
             'wav': wav,
-            'output_lengths': torch.tensor(orig_output, dtype=torch.long),
+            'wav_lengths': torch.tensor(orig_output, dtype=torch.long),
             'filenames': path
         }
         return tseq, wav, path, text
@@ -159,9 +159,9 @@ class TextMelCollate():

         return {
             'padded_text': text_padded,
-            'input_lengths': input_lengths,
+            'text_lengths': input_lengths,
             'wav': wav_padded,
-            'output_lengths': output_lengths,
+            'wav_lengths': output_lengths,
             'filenames': filenames,
             'real_text': real_text,
         }
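These two hunks are a pure key rename in the batch dicts: `input_lengths` becomes `text_lengths` and `output_lengths` becomes `wav_lengths`, so each length is named after the modality it measures. A minimal sketch of a downstream consumer reading the renamed keys (`loader` here is a hypothetical DataLoader built on TextWavLoader with TextMelCollate):

for batch in loader:
    text = batch['padded_text']        # (B, T_text) padded token ids
    text_lens = batch['text_lengths']  # (B,) true text lengths, was 'input_lengths'
    wav = batch['wav']                 # padded waveforms
    wav_lens = batch['wav_lengths']    # (B,) true sample counts, was 'output_lengths'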
@@ -171,14 +171,14 @@ if __name__ == '__main__':
     batch_sz = 32
     params = {
         'mode': 'nv_tacotron',
-        'path': 'E:\\audio\\MozillaCommonVoice\\en\\test.tsv',
+        'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
         'phase': 'train',
-        'n_workers': 0,
+        'n_workers': 1,
         'batch_size': batch_sz,
-        'fetcher_mode': 'mozilla_cv',
+        'fetcher_mode': ['libritts'],
         'needs_collate': True,
-        #'max_wav_length': 256000,
-        #'max_text_length': 200,
+        'max_wav_length': 256000,
+        'max_text_length': 200,
         'sample_rate': 22050,
     }
     from data import create_dataset, create_dataloader
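The smoke test above now points at a LibriTTS list file instead of a Mozilla Common Voice TSV, bumps `n_workers` to 1, wraps `path` and `fetcher_mode` in lists, and turns the length caps on. A sketch of how `params` presumably feeds the factories imported in the last context line; the exact signatures of `create_dataset`/`create_dataloader` are an assumption here:

from data import create_dataset, create_dataloader

ds = create_dataset(params)         # assumed: builds the dataset from the opt dict
dl = create_dataloader(ds, params)  # assumed: wraps it with batch_size/n_workers
for i, batch in enumerate(dl):
    print(i, batch['padded_text'].shape, batch['wav'].shape)
    if i == 2:
        break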
@@ -202,7 +202,8 @@ class GPT2InferenceModel(GPT2PreTrainedModel):

 class GptAsrHf2(nn.Module):
     NUMBER_SYMBOLS = len(symbols)
-    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+1
+    START_TOKEN = NUMBER_SYMBOLS
+    NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS+2

     def __init__(self, layers=8, model_dim=512, heads=8, max_symbols_per_phrase=800, max_mel_frames=3000, checkpointing=True):
         super().__init__()
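`START_TOKEN` pins the start-of-sequence id to `len(symbols)`, and the vocabulary grows from `NUMBER_SYMBOLS+1` to `NUMBER_SYMBOLS+2`, reserving one more id above the start token; the CombineMelInjector added later in this commit uses `len(symbols)+1` as a text separator, which fits that extra slot. A sketch of the implied id layout (the symbol count is illustrative):

import torch.nn as nn

NUMBER_SYMBOLS = 148                     # hypothetical stand-in for len(symbols)
START_TOKEN = NUMBER_SYMBOLS             # id 148: start-of-sequence
SEPARATOR = NUMBER_SYMBOLS + 1           # id 149: used by CombineMelInjector below
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + 2  # embedding table must cover both special ids
emb = nn.Embedding(NUMBER_TEXT_TOKENS, 512)  # ids 0..147 are text symbols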
@@ -230,7 +231,7 @@ class GptAsrHf2(nn.Module):

     def get_logits(self, mel_inputs, text_targets, get_attns=False):
         # Pad front remove last element to set up next token prediction. Pad at front is the "START" token.
-        text_targets = F.pad(text_targets, (1,0), value=self.NUMBER_SYMBOLS)[:, :-1]
+        text_targets = F.pad(text_targets, (1,0), value=self.START_TOKEN)[:, :-1]
         text_emb = self.gpt.get_input_embeddings()(text_targets)
         text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_targets.device))
         mel_emb = self.mel_encoder(mel_inputs)
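The padded line implements the usual teacher-forcing shift: prepend the start token and drop the last target so that position i of the input is trained to predict target i; the change only swaps the hard-coded `NUMBER_SYMBOLS` for the new `START_TOKEN` alias. A runnable illustration of the shift:

import torch
import torch.nn.functional as F

START_TOKEN = 148                        # hypothetical id, as above
targets = torch.tensor([[5, 9, 2, 7]])   # (B=1, T=4) ground-truth token ids
inputs = F.pad(targets, (1, 0), value=START_TOKEN)[:, :-1]
print(inputs)                            # tensor([[148,   5,   9,   2]])
# the logit computed at inputs[:, i] is scored against targets[:, i]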
@@ -3,6 +3,8 @@

 import numpy as np
+import random
+
 import torch
 import torchvision.utils

 from trainer.inject import Injector
@@ -34,6 +36,7 @@ def spec_augment(mel_spectrogram, frequency_masking_para=27, time_masking_para=7

     return mel_spectrogram

+
 class MelMaskInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
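For context on the file being touched: per the hunk header, `spec_augment` applies SpecAugment-style frequency and time masking to a mel spectrogram, and the new `import random` supports that sampling. A minimal standalone sketch of the technique under the usual parameter meanings, not the repo's exact implementation:

import random
import torch

def spec_augment_sketch(mel, freq_para=27, time_para=70, freq_masks=1, time_masks=1):
    # mel: (..., n_mels, n_frames); returns a copy with random bands zeroed.
    mel = mel.clone()
    n_mels, n_frames = mel.shape[-2], mel.shape[-1]
    for _ in range(freq_masks):   # mask a random run of mel bins
        f = random.randrange(0, freq_para)
        f0 = random.randrange(0, max(1, n_mels - f))
        mel[..., f0:f0 + f, :] = 0
    for _ in range(time_masks):   # mask a random run of frames
        t = random.randrange(0, time_para)
        t0 = random.randrange(0, max(1, n_frames - t))
        mel[..., :, t0:t0 + t] = 0
    return mel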
@@ -54,7 +57,8 @@ def visualization_spectrogram(spec, title):
     spec = ((spec + 1) / 2).clip(0, 1)
     torchvision.utils.save_image(spec, f'{title}.png')

 if __name__ == '__main__':

+    def test_mel_mask():
         from data.audio.unsupervised_audio_dataset import load_audio
         from trainer.injectors.base_injectors import MelSpectrogramInjector
         spec_maker = MelSpectrogramInjector({'in': 'audio', 'out': 'spec'}, {})
@@ -63,3 +67,57 @@ if __name__ == '__main__':
         visualization_spectrogram(s, 'original spec')
         saug = spec_augment(s, 50, 5, 1, 3)
         visualization_spectrogram(saug, 'modified spec')
+
+
+'''
+Crafty bespoke injector that is used when training ASR models to create longer sequences to ensure that the entire
+input length embedding is trained. Does this by concatenating every other batch element together to create longer
+sequences which (theoretically) use similar amounts of GPU memory.
+'''
+class CombineMelInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.audio_key = opt['audio_key']
+        self.text_key = opt['text_key']
+        self.audio_lengths = opt['audio_lengths_key']
+        self.text_lengths = opt['text_lengths_key']
+        from models.tacotron2.text import symbols
+        self.text_separator = len(symbols)+1  # Probably need to allow this to be set by user.
+
+    def forward(self, state):
+        audio = state[self.audio_key]
+        texts = state[self.text_key]
+        audio_lengths = state[self.audio_lengths]
+        text_lengths = state[self.text_lengths]
+        assert audio.shape[0] % 2 == 0  # Make sure there are an even number of batches.
+        combined_audios = []
+        combined_texts = []
+        for b in range(audio.shape[0]//2):
+            a1 = audio[b*2, :audio_lengths[b*2]]
+            a2 = audio[b*2+1, :audio_lengths[b*2+1]]
+            a = torch.cat([a1, a2], dim=0)
+            a = torch.nn.functional.pad(a, (0, audio.shape[-1]*2-a.shape[-1]))
+            combined_audios.append(a)
+
+            t1 = texts[b*2, :text_lengths[b*2]]
+            t1 = torch.nn.functional.pad(t1, (0, 1), value=self.text_separator)
+            t2 = texts[b*2+1, :text_lengths[b*2+1]]
+            t = torch.cat([t1, t2], dim=0)
+            t = torch.nn.functional.pad(t, (0, texts.shape[-1]*2-t.shape[-1]))
+            combined_texts.append(t)
+        return {self.audio_key: torch.stack(combined_audios, dim=0),
+                self.text_key: torch.stack(combined_texts, dim=0)}
+
+
+def test_mel_injector():
+    inj = CombineMelInjector({'audio_key': 'a', 'text_key': 't', 'audio_lengths_key': "alk", 'text_lengths_key': 'tlk'}, {})
+    a = torch.rand((4, 22000))
+    al = torch.tensor([11000,14000,22000,20000])
+    t = torch.randint(0, 120, (4, 250))
+    tl = torch.tensor([100,120,200,250])
+    rs = inj({'a': a, 't': t, 'alk': al, 'tlk': tl})
+
+
+
+
+if __name__ == '__main__':
+    test_mel_injector()
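Tracing `test_mel_injector` through `forward` shows the intended shape behavior: four clips padded to 22000 samples come out as two clips padded to 44000, and each pair of token sequences is joined with the separator id between them. A usage sketch (the import path is an assumption):

import torch
from trainer.injectors.spec_augment import CombineMelInjector  # assumed module path

inj = CombineMelInjector({'audio_key': 'a', 'text_key': 't',
                          'audio_lengths_key': 'alk', 'text_lengths_key': 'tlk'}, {})
a = torch.rand((4, 22000))
al = torch.tensor([11000, 14000, 22000, 20000])
t = torch.randint(0, 120, (4, 250))
tl = torch.tensor([100, 120, 200, 250])
out = inj({'a': a, 't': t, 'alk': al, 'tlk': tl})
print(out['a'].shape)  # torch.Size([2, 44000]): pairs concatenated, re-padded to 2x
print(out['t'].shape)  # torch.Size([2, 500]): separator inserted between the texts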