forked from mrq/DL-Art-School

grand conjoined dataset: support collating

parent 8a02ba5935
commit 53784ec806

codes/data/__init__.py
@@ -88,6 +88,9 @@ def create_dataset(dataset_opt, return_collate=False):
         from data.audio.audio_with_noise_dataset import AudioWithNoiseDataset as D
     elif mode == 'grand_conjoined_voice':
         from data.audio.grand_conjoined_dataset import GrandConjoinedDataset as D
+        from data.zero_pad_dict_collate import ZeroPadDictCollate as C
+        if opt_get(dataset_opt, ['needs_collate'], True):
+            collate = C()
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
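For context, opt_get is the repo's nested-option lookup helper. A minimal sketch of the semantics assumed by the call above (the real implementation lives elsewhere in the repo and may differ):

def opt_get(opt, keys, default=None):
    # Walk a nested dict along `keys`; fall back to `default` if any level is missing.
    for k in keys:
        if not isinstance(opt, dict) or k not in opt:
            return default
        opt = opt[k]
    return opt

With that reading, needs_collate defaults to True for this mode, so a ZeroPadDictCollate instance is built unless the dataset options explicitly opt out.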
codes/data/audio/grand_conjoined_dataset.py
@@ -55,10 +55,11 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         self.max_paired_text_length = opt['max_paired_text_length']
         self.max_solo_audio_length = opt['max_solo_audio_length']
         self.max_solo_text_length = opt['max_solo_text_length']
+        self.collate = opt_get(opt, ['needs_collate'], False)
         self.sample_rate = sample_rate

         # Set some sane arguments for all three datasets.
-        paired_dataset_args['needs_collate'] = False
+        paired_dataset_args['needs_collate'] = self.collate
         paired_dataset_args['load_conditioning'] = False
         paired_dataset_args['sample_rate'] = sample_rate
         paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
@@ -69,6 +70,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
             unsupervised_audio_args['sampling_rate'] = sample_rate
             unsupervised_audio_args['do_augmentation'] = False
             unsupervised_audio_args['resample_clip'] = False
+            if self.collate:
                 unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
             self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
             self.text = HfDataset(**text_corpus_args)
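pad_to_samples is consumed inside UnsupervisedAudioDataset, which this commit does not show. A sketch of the behavior it presumably implements, bounding every solo clip to a fixed sample count (pad_or_clip is a hypothetical helper, not part of the repo):

import torch
import torch.nn.functional as F

def pad_or_clip(clip: torch.Tensor, n_samples: int) -> torch.Tensor:
    # clip: (channels, samples). Zero-pad short clips out to n_samples;
    # truncate clips that run longer than that.
    if clip.shape[-1] < n_samples:
        return F.pad(clip, (0, n_samples - clip.shape[-1]))
    return clip[:, :n_samples]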
@@ -76,6 +78,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
     def fetch_text_at(self, i):
         try:
             txt = self.text[i % len(self.text)]['text']
+            assert '*' not in txt  # This is a hack to get around the use of '*' to mask expletives in some text-only datasets. There really isn't a linguistic use for this character anyways.
             tok = self.speech_and_text.get_text(txt)
             padding_required = self.max_solo_text_length - tok.shape[0]
             if padding_required < 0:
@@ -89,8 +92,23 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
             # handle (e.g. ones with emojis, or other languages). Just return another one.
             return self.fetch_text_at((i+1) % len(self.text))

+    def fetch_snt_at(self, i):
+        fetched = self.speech_and_text[i % len(self.speech_and_text)]
+        if self.collate:
+            tseq, wav, path, text, cond = fetched
+            return {
+                'real_text': text,
+                'padded_text': tseq,
+                'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
+                'wav': wav,
+                'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
+                'filenames': path
+            }
+        else:
+            return fetched
+
     def __getitem__(self, i):
-        snt = self.speech_and_text[i % len(self.speech_and_text)]
+        snt = self.fetch_snt_at(i)
         if self.only_paired:
             return {
                 'paired_audio': snt['wav'],
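Note that fetch_snt_at stores the true text/audio lengths next to the tensors. One reason to carry text_lengths and wav_lengths explicitly: after ZeroPadDictCollate batches variable-length items, downstream code can rebuild a padding mask from them. A toy illustration (length_mask is hypothetical, not from the repo):

import torch

def length_mask(lengths: torch.Tensor, padded_len: int) -> torch.Tensor:
    # lengths: (batch,) -> (batch, padded_len) bool mask, True where real samples remain.
    return torch.arange(padded_len)[None, :] < lengths[:, None]

mask = length_mask(torch.tensor([3, 5]), 5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])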
@@ -105,8 +123,11 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'text_tokens': snt['padded_text'],
             }
         else:
-            sp = self.speech[i % len(self.speech)]
             txt, txt_tok = self.fetch_text_at(i % len(self.text))
+            sp = self.speech[i % len(self.speech)]
+            # Set upper bound on solo speech lengths. This is handled automatically when collation is turned off, but needs to be done otherwise.
+            sp['clip'] = sp['clip'][:, :self.max_solo_audio_length]
+            sp['clip_lengths'] = clamp(sp['clip_lengths'], 0, self.max_solo_audio_length)
             return {
                 'paired_audio': snt['wav'],
                 'paired_audio_lengths': snt['wav_lengths'],
@@ -114,7 +135,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
                 'paired_text_tokens': snt['padded_text'],
                 'paired_file': snt['filenames'],
                 'speech_audio': sp['clip'],
-                'speech_audio_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
+                'speech_audio_lengths': sp['clip_lengths'],
                 'speech_file': sp['path'],
                 'text_text': txt,
                 'text_tokens': txt_tok,
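The clamp now happens at the same point where the clip itself is truncated, so speech_audio and speech_audio_lengths stay consistent by construction. A toy check of that invariant (a sketch using torch.clamp; the module's own clamp import may be the same function):

import torch

max_solo_audio_length = 4
clip = torch.randn(1, 6)                      # (channels, samples)
clip_lengths = torch.tensor(6)
clip = clip[:, :max_solo_audio_length]
clip_lengths = torch.clamp(clip_lengths, 0, max_solo_audio_length)
assert clip_lengths.item() <= clip.shape[-1]  # holds: 4 <= 4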
@@ -142,6 +163,7 @@ if __name__ == '__main__':
         'paired_dataset_args': {
             'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
             'fetcher_mode': ['libritts'],
+            'use_bpe_tokenizer': False,
         },
         'unsupervised_audio_args': {
             'path': ['Z:\\bigasr_dataset\\librispeech\\test_clean'],
@@ -163,15 +185,17 @@ if __name__ == '__main__':
         'max_solo_text_length': 330,
         'max_solo_audio_length': 300000,
         'only_paired': True,
+        'needs_collate': True,
         'paired_dataset_args': {
             'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
             'fetcher_mode': ['libritts'],
+            'use_bpe_tokenizer': False,
         },
     }
     from data import create_dataset, create_dataloader

-    ds = create_dataset(val_params)
+    ds, c = create_dataset(train_params, return_collate=True)
-    dl = create_dataloader(ds, val_params)
+    dl = create_dataloader(ds, train_params, collate_fn=c)

     def save(b, i, ib, key):
         torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)
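For reference, the create_dataloader call with collate_fn presumably reduces to standard PyTorch wiring along these lines (batch size and flags are illustrative, not from the repo):

import torch
from torch.utils.data import DataLoader

# ds, c as returned above by create_dataset(train_params, return_collate=True)
loader = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=c)
for batch in loader:
    # Every tensor-valued key is zero-padded to the longest item in the batch.
    print({k: v.shape for k, v in batch.items() if isinstance(v, torch.Tensor)})
    break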
codes/data/zero_pad_dict_collate.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+
+class ZeroPadDictCollate():
+    """
+    Given a list of dictionary outputs with torch.Tensors from a Dataset, iterates through each one, finds the longest
+    tensor, and zero pads all the other tensors together.
+    """
+    def collate_tensors(self, batch, key):
+        result = []
+        largest_dims = [0 for _ in range(len(batch[0][key].shape))]
+        for elem in batch:
+            result.append(elem[key])
+            largest_dims = [max(current_largest, new_consideration) for current_largest, new_consideration in zip(largest_dims, elem[key].shape)]
+        # Now pad each tensor by the largest dimension.
+        for i in range(len(result)):
+            padding_tuple = ()
+            for d in range(len(largest_dims)):
+                padding_needed = largest_dims[d] - result[i].shape[d]
+                assert padding_needed >= 0
+                padding_tuple = (0, padding_needed) + padding_tuple
+            result[i] = F.pad(result[i], padding_tuple)
+
+        return torch.stack(result, dim=0)
+
+
+    def collate_into_list(self, batch, key):
+        result = []
+        for elem in batch:
+            result.append(elem[key])
+        return result
+
+    def __call__(self, batch):
+        first_dict = batch[0]
+        collated = {}
+        for key in first_dict.keys():
+            if isinstance(first_dict[key], torch.Tensor):
+                collated[key] = self.collate_tensors(batch, key)
+            else:
+                collated[key] = self.collate_into_list(batch, key)
+        return collated
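A quick toy usage of the new collator (the tensors and paths below are made up for illustration; run from codes/ so the data package imports):

import torch
from data.zero_pad_dict_collate import ZeroPadDictCollate

c = ZeroPadDictCollate()
batch = [
    {'wav': torch.ones(1, 3), 'path': 'a.wav'},
    {'wav': torch.ones(1, 5), 'path': 'b.wav'},
]
out = c(batch)
print(out['wav'].shape)  # torch.Size([2, 1, 5]): the shorter clip is zero-padded
print(out['path'])       # ['a.wav', 'b.wav']: non-tensor values are gathered into a list

Note the padding is always appended at the end of each dimension, which matches how the *_lengths fields returned by the dataset are meant to be consumed.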