diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 177bd249..53f3c309 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -88,6 +88,9 @@ def create_dataset(dataset_opt, return_collate=False):
         from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
     elif mode == 'grand_conjoined_voice':
         from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
+        from data.zero_pad_dict_collate import ZeroPadDictCollate as C
+        if opt_get(dataset_opt, ['needs_collate'], True):
+            collate = C()
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py
index aa30f24f..7101b0f9 100644
--- a/codes/data/audio/grand_conjoined_dataset.py
+++ b/codes/data/audio/grand_conjoined_dataset.py
@@ -55,10 +55,11 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         self.max_paired_text_length = opt['max_paired_text_length']
         self.max_solo_audio_length = opt['max_solo_audio_length']
         self.max_solo_text_length = opt['max_solo_text_length']
+        self.collate = opt_get(opt, ['needs_collate'], False)
         self.sample_rate = sample_rate
 
         # Set some sane arguments for all three datasets.
-        paired_dataset_args['needs_collate'] = False
+        paired_dataset_args['needs_collate'] = self.collate
         paired_dataset_args['load_conditioning'] = False
         paired_dataset_args['sample_rate'] = sample_rate
         paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
@@ -69,13 +70,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         unsupervised_audio_args['sampling_rate'] = sample_rate
         unsupervised_audio_args['do_augmentation'] = False
         unsupervised_audio_args['resample_clip'] = False
-        unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
+        if not self.collate:
+            unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
         self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
         self.text = HfDataset(**text_corpus_args)
 
     def fetch_text_at(self, i):
         try:
             txt = self.text[i % len(self.text)]['text']
+            assert '*' not in txt  # Hack to reject texts that use '*' to mask expletives in some text-only datasets. There is no linguistic use for this character anyway.
             tok = self.speech_and_text.get_text(txt)
             padding_required = self.max_solo_text_length - tok.shape[0]
             if padding_required < 0:
@@ -89,8 +92,23 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
             # handle (e.g. ones with emojis, or other languages). Just return another one.
             return self.fetch_text_at((i+1) % len(self.text))
 
+    def fetch_snt_at(self, i):
+        fetched = self.speech_and_text[i % len(self.speech_and_text)]
+        if self.collate:
+            tseq, wav, path, text, cond = fetched
+            return {
+                'real_text': text,
+                'padded_text': tseq,
+                'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
+                'wav': wav,
+                'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
+                'filenames': path
+            }
+        else:
+            return fetched
+
     def __getitem__(self, i):
-        snt = self.speech_and_text[i % len(self.speech_and_text)]
+        snt = self.fetch_snt_at(i)
         if self.only_paired:
             return {
                 'paired_audio': snt['wav'],
@@ -105,8 +123,11 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'text_tokens': snt['padded_text'],
             }
         else:
-            sp = self.speech[i % len(self.speech)]
             txt, txt_tok = self.fetch_text_at(i % len(self.text))
+            sp = self.speech[i % len(self.speech)]
+            # Set an upper bound on solo speech lengths. This is handled automatically when collation is turned off, but must be done manually otherwise.
+            sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
+            sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
             return {
                 'paired_audio': snt['wav'],
                 'paired_audio_lengths': snt['wav_lengths'],
@@ -114,7 +135,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'paired_text_tokens': snt['padded_text'],
                 'paired_file': snt['filenames'],
                 'speech_audio': sp['clip'],
-                'speech_audio_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
+                'speech_audio_lengths': sp['clip_lengths'],
                 'speech_file': sp['path'],
                 'text_text': txt,
                 'text_tokens': txt_tok,
@@ -142,6 +163,7 @@ if __name__ == '__main__':
         'paired_dataset_args': {
             'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
             'fetcher_mode': ['libritts'],
+            'use_bpe_tokenizer': False,
         },
         'unsupervised_audio_args': {
             'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
@@ -163,15 +185,17 @@ if __name__ == '__main__':
         'max_solo_text_length': 330,
         'max_solo_audio_length': 300000,
         'only_paired': True,
+        'needs_collate': True,
         'paired_dataset_args': {
             'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
             'fetcher_mode': ['libritts'],
+            'use_bpe_tokenizer': False,
         },
     }
 
     from data import create_dataset, create_dataloader
-    ds = create_dataset(val_params)
-    dl = create_dataloader(ds, val_params)
+    ds, c = create_dataset(train_params, return_collate=True)
+    dl = create_dataloader(ds, train_params, collate_fn=c)
 
     def save(b, i, ib, key):
         torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
diff --git a/codes/data/zero_pad_dict_collate.py b/codes/data/zero_pad_dict_collate.py
new file mode 100644
index 00000000..f3b91343
--- /dev/null
+++ b/codes/data/zero_pad_dict_collate.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+
+class ZeroPadDictCollate():
+    """
+    Given a list of dictionary outputs containing torch.Tensors from a Dataset, iterates through each key, finds the
+    largest tensor for that key, and zero-pads all the other tensors to match before stacking them into a batch.
+    """
+    def collate_tensors(self, batch, key):
+        result = []
+        largest_dims = [0 for _ in range(len(batch[0][key].shape))]
+        for elem in batch:
+            result.append(elem[key])
+            largest_dims = [max(current_largest, new_consideration) for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
+        # Now pad each tensor up to the largest size found for each dimension.
+        for i in range(len(result)):
+            padding_tuple = ()
+            for d in range(len(largest_dims)):
+                padding_needed = largest_dims[d] - result[i].shape[d]
+                assert padding_needed >= 0
+                padding_tuple = (0, padding_needed) + padding_tuple
+            result[i] = F.pad(result[i], padding_tuple)
+
+        return torch.stack(result, dim=0)
+
+
+    def collate_into_list(self, batch, key):
+        result = []
+        for elem in batch:
+            result.append(elem[key])
+        return result
+
+    def __call__(self, batch):
+        first_dict = batch[0]
+        collated = {}
+        for key in first_dict.keys():
+            if isinstance(first_dict[key], torch.Tensor):
+                collated[key] = self.collate_tensors(batch, key)
+            else:
+                collated[key] = self.collate_into_list(batch, key)
+        return collated
\ No newline at end of file
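
For reviewers: a minimal sketch (not part of the patch) of what ZeroPadDictCollate produces for a toy batch of variable-length dict samples. The key names echo the dataset outputs, but the tensors and filenames are made up; run it from the codes/ directory so the import resolves.

import torch

from data.zero_pad_dict_collate import ZeroPadDictCollate

# Two samples whose tensors differ in length along the final dimension.
batch = [
    {'padded_text': torch.ones(12, dtype=torch.long), 'wav': torch.randn(1, 100), 'filenames': 'a.wav'},
    {'padded_text': torch.ones(7, dtype=torch.long), 'wav': torch.randn(1, 140), 'filenames': 'b.wav'},
]
out = ZeroPadDictCollate()(batch)
assert out['padded_text'].shape == (2, 12)     # the 7-token sequence is zero-padded up to 12
assert out['wav'].shape == (2, 1, 140)         # the 100-sample clip is zero-padded up to 140
assert out['filenames'] == ['a.wav', 'b.wav']  # non-tensor values are gathered into a plain list

The same instance can also be passed directly to a torch.utils.data.DataLoader as its collate_fn, which is effectively what the updated __main__ block does through create_dataloader.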