diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py
index 4d095b04..7ba5bdb7 100644
--- a/codes/data/audio/grand_conjoined_dataset.py
+++ b/codes/data/audio/grand_conjoined_dataset.py
@@ -44,10 +44,12 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
     Performs tokenization at this level, ignoring any tokenization performed by upstream datasets.
     """
     def __init__(self, opt):
+        sample_rate = 22050  # Fixed.
         paired_dataset_args = opt['paired_dataset_args']
-        unsupervised_audio_args = opt['unsupervised_audio_args']
-        text_corpus_args = opt['text_corpus_args']
-        sample_rate = 22050
+        self.only_paired = opt_get(opt, ['only_paired'], False)
+        if not self.only_paired:
+            unsupervised_audio_args = opt['unsupervised_audio_args']
+            text_corpus_args = opt['text_corpus_args']
 
         self.max_paired_audio_length = opt['max_paired_audio_length']
         self.max_paired_text_length = opt['max_paired_text_length']
@@ -61,14 +63,15 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
         paired_dataset_args['sample_rate'] = sample_rate
         paired_dataset_args['max_wav_length'] = self.max_paired_audio_length
         paired_dataset_args['max_text_length'] = self.max_paired_text_length
-        unsupervised_audio_args['sampling_rate'] = sample_rate
-        unsupervised_audio_args['do_augmentation'] = False
-        unsupervised_audio_args['resample_clip'] = False
-        unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
-
-        self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)
-        self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
-        self.text = HfDataset(**text_corpus_args)
+        self.speech_and_text = build_paired_voice_dataset(paired_dataset_args)
+
+        if not self.only_paired:
+            unsupervised_audio_args['sampling_rate'] = sample_rate
+            unsupervised_audio_args['do_augmentation'] = False
+            unsupervised_audio_args['resample_clip'] = False
+            unsupervised_audio_args['pad_to_samples'] = self.max_solo_audio_length
+            self.speech = UnsupervisedAudioDataset(unsupervised_audio_args)
+            self.text = HfDataset(**text_corpus_args)
 
     def fetch_text_at(self, i):
         try:
@@ -76,7 +79,7 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
             tok = self.speech_and_text.get_text(txt)
             padding_required = self.max_solo_text_length - tok.shape[0]
             if padding_required < 0:
-                # Just truncate since there is no conditioning requried.
+                # Just truncate since there is no conditioning required.
                 tok = tok[:self.max_solo_text_length]
             elif padding_required > 0:
                 tok = F.pad(tok, (0, padding_required))
@@ -88,29 +91,45 @@ class GrandConjoinedDataset(torch.utils.data.Dataset):
 
     def __getitem__(self, i):
         snt = self.speech_and_text[i % len(self.speech_and_text)]
-        sp = self.speech[i % len(self.speech)]
-        txt, txt_tok = self.fetch_text_at(i % len(self.text))
-
-        return {
-            'paired_audio': snt['wav'],
-            'paired_audio_lengths': snt['wav_lengths'],
-            'paired_text': snt['real_text'],
-            'paired_text_tokens': snt['padded_text'],
-            'paired_file': snt['filenames'],
-            'speech_audio': sp['clip'],
-            'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
-            'speech_file': sp['path'],
-            'text_text': txt,
-            'text_tokens': txt_tok,
-        }
+        if self.only_paired:
+            return {
+                'paired_audio': snt['wav'],
+                'paired_audio_lengths': snt['wav_lengths'],
+                'paired_text': snt['real_text'],
+                'paired_text_tokens': snt['padded_text'],
+                'paired_file': snt['filenames'],
+                'speech_audio': snt['wav'],
+                'speech_lengths': snt['wav_lengths'],
+                'speech_file': snt['filenames'],
+                'text_text': snt['real_text'],
+                'text_tokens': snt['padded_text'],
+            }
+        else:
+            sp = self.speech[i % len(self.speech)]
+            txt, txt_tok = self.fetch_text_at(i % len(self.text))
+            return {
+                'paired_audio': snt['wav'],
+                'paired_audio_lengths': snt['wav_lengths'],
+                'paired_text': snt['real_text'],
+                'paired_text_tokens': snt['padded_text'],
+                'paired_file': snt['filenames'],
+                'speech_audio': sp['clip'],
+                'speech_lengths': clamp(sp['clip_lengths'], 0, self.max_solo_audio_length),
+                'speech_file': sp['path'],
+                'text_text': txt,
+                'text_tokens': txt_tok,
+            }
 
     def __len__(self):
-        return max(len(self.speech), len(self.speech_and_text), len(self.text))
+        if self.only_paired:
+            return len(self.speech_and_text)
+        else:
+            return max(len(self.speech), len(self.speech_and_text), len(self.text))
 
 
 if __name__ == '__main__':
     batch_sz = 8
-    params = {
+    train_params = {
         'mode': 'grand_conjoined_voice',
         'phase': 'train',
         'n_workers': 0,
@@ -133,10 +152,26 @@ if __name__ == '__main__':
             'cache_path': 'Z:\\huggingface_datasets\\cache',
         },
     }
+    val_params = {
+        'mode': 'grand_conjoined_voice',
+        'phase': 'val',
+        'n_workers': 0,
+        'batch_size': batch_sz,
+
+        'max_paired_audio_length': 255995,
+        'max_paired_text_length': 80,
+        'max_solo_text_length': 330,
+        'max_solo_audio_length': 300000,
+        'only_paired': True,
+        'paired_dataset_args': {
+            'path': ['Z:\\bigasr_dataset\\libritts\\test-clean_list.txt'],
+            'fetcher_mode': ['libritts'],
+        },
+    }
     from data import create_dataset, create_dataloader
 
-    ds = create_dataset(params)
-    dl = create_dataloader(ds, params)
+    ds = create_dataset(val_params)
+    dl = create_dataloader(ds, val_params)
 
     def save(b, i, ib, key):
         torchaudio.save(f'{i}_clip_{ib}_{key}.wav', b[key][ib], 22050)