diff --git a/codes/data/audio/grand_conjoined_dataset.py b/codes/data/audio/grand_conjoined_dataset.py index e0e1bad6..c0debe6e 100644 --- a/codes/data/audio/grand_conjoined_dataset.py +++ b/codes/data/audio/grand_conjoined_dataset.py @@ -235,9 +235,9 @@ if __name__ == '__main__': m = None for i, b in tqdm(enumerate(dl)): for ib in range(batch_sz): - save(b, i, ib, 'paired_audio') - save(b, i, ib, 'paired_audio_conditioning', 0) - save(b, i, ib, 'paired_audio_conditioning', 1) + #save(b, i, ib, 'paired_audio') + #save(b, i, ib, 'paired_audio_conditioning', 0) + #save(b, i, ib, 'paired_audio_conditioning', 1) print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}') print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}') #save(b, i, ib, 'speech_audio') diff --git a/codes/data/zero_pad_dict_collate.py b/codes/data/zero_pad_dict_collate.py index f3b91343..8d42aea5 100644 --- a/codes/data/zero_pad_dict_collate.py +++ b/codes/data/zero_pad_dict_collate.py @@ -35,7 +35,7 @@ class ZeroPadDictCollate(): first_dict = batch[0] collated = {} for key in first_dict.keys(): - if isinstance(first_dict[key], torch.Tensor): + if isinstance(first_dict[key], torch.Tensor) and len(first_dict[key].shape) > 0: collated[key] = self.collate_tensors(batch, key) else: collated[key] = self.collate_into_list(batch, key)