diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py index 21ad3440..45f4bb46 100644 --- a/codes/data/audio/grand_conjoined_dataset.py +++ b/codes/data/audio/grand_conjoined_dataset.py @@ -174,15 +174,15 @@ if __name__ == '__main__': 'batch_size': batch_sz, 'max_paired_audio_length': 255995, - 'max_paired_text_length': 200, - 'max_solo_text_length': 330, - 'max_solo_audio_length': 300000, - 'needs_collate': False, + 'max_paired_text_length': 100, + 'max_solo_text_length': 200, + 'max_solo_audio_length': 307195, 'num_conditioning_candidates': 1, 'conditioning_length': 44000, + 'needs_collate': True, 'paired_dataset_args': { - 'path': ['Y:\\clips\\podcasts-0-transcribed.tsv'], - 'fetcher_mode': ['tsv'], + 'path': ['Z:\\bigasr_dataset\\tedlium\\train-all.txt'], + 'fetcher_mode': ['libritts'], 'use_bpe_tokenizer': False, }, 'unsupervised_audio_args': { diff --git a/codes/data/zero_pad_dict_collate.py b/codes/data/zero_pad_dict_collate.py index 8d42aea5..5423f63a 100644 --- a/codes/data/zero_pad_dict_collate.py +++ b/codes/data/zero_pad_dict_collate.py @@ -35,8 +35,11 @@ class ZeroPadDictCollate(): first_dict = batch[0] collated = {} for key in first_dict.keys(): - if isinstance(first_dict[key], torch.Tensor) and len(first_dict[key].shape) > 0: - collated[key] = self.collate_tensors(batch, key) + if isinstance(first_dict[key], torch.Tensor): + if len(first_dict[key].shape) > 0: + collated[key] = self.collate_tensors(batch, key) + else: + collated[key] = torch.stack([b[key] for b in batch]) else: collated[key] = self.collate_into_list(batch, key) return collated \ No newline at end of file