This commit is contained in:
James Betker 2022-01-01 10:31:03 -07:00
parent d5a5111890
commit 35abefd038

View File

@ -39,7 +39,7 @@ class ZeroPadDictCollate():
if len(first_dict[key].shape) > 0:
collated[key] = self.collate_tensors(batch, key)
else:
collated[key] = torch.stack(batch[key])
collated[key] = torch.stack([b[key] for b in batch])
else:
collated[key] = self.collate_into_list(batch, key)
return collated