forked from mrq/DL-Art-School
Fix collator bug
This commit is contained in:
parent
c28d8770c7
commit
56752f1dbc
|
@ -235,9 +235,9 @@ if __name__ == '__main__':
|
|||
m = None
|
||||
for i, b in tqdm(enumerate(dl)):
|
||||
for ib in range(batch_sz):
|
||||
save(b, i, ib, 'paired_audio')
|
||||
save(b, i, ib, 'paired_audio_conditioning', 0)
|
||||
save(b, i, ib, 'paired_audio_conditioning', 1)
|
||||
#save(b, i, ib, 'paired_audio')
|
||||
#save(b, i, ib, 'paired_audio_conditioning', 0)
|
||||
#save(b, i, ib, 'paired_audio_conditioning', 1)
|
||||
print(f'Paired file: {b["paired_file"][ib]} text: {b["paired_text"][ib]}')
|
||||
print(f'Paired text decoded: {decode(b, ib, "paired_text_tokens")}')
|
||||
#save(b, i, ib, 'speech_audio')
|
||||
|
|
|
@ -35,7 +35,7 @@ class ZeroPadDictCollate():
|
|||
first_dict = batch[0]
|
||||
collated = {}
|
||||
for key in first_dict.keys():
|
||||
if isinstance(first_dict[key], torch.Tensor):
|
||||
if isinstance(first_dict[key], torch.Tensor) and len(first_dict[key].shape) > 0:
|
||||
collated[key] = self.collate_tensors(batch, key)
|
||||
else:
|
||||
collated[key] = self.collate_into_list(batch, key)
|
||||
|
|
Loading…
Reference in New Issue
Block a user