Fix collator bug

This commit is contained in:
James Betker 2022-01-01 00:33:31 -07:00
parent c28d8770c7
commit 56752f1dbc
2 changed files with 4 additions and 4 deletions

View File

@ -235,9 +235,9 @@ if __name__ == '__main__':
m = None
for i, b in tqdm(enumerate(dl)):
for ib in range(batch_sz):
save(b, i, ib, 'paired_audio')
save(b, i, ib, 'paired_audio_conditioning', 0)
save(b, i, ib, 'paired_audio_conditioning', 1)
#save(b, i, ib, 'paired_audio')
#save(b, i, ib, 'paired_audio_conditioning', 0)
#save(b, i, ib, 'paired_audio_conditioning', 1)
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
#save(b, i, ib, 'speech_audio')

View File

@ -35,7 +35,7 @@ class ZeroPadDictCollate():
first_dict = batch[0]
collated = {}
for key in first_dict.keys():
if isinstance(first_dict[key], torch.Tensor):
if isinstance(first_dict[key], torch.Tensor) and len(first_dict[key].shape) > 0:
collated[key] = self.collate_tensors(batch, key)
else:
collated[key] = self.collate_into_list(batch, key)