DL-Art-School/codes/data/zero_pad_dict_collate.py
James Betker 35abefd038 More fix
2022-01-01 10:31:03 -07:00

45 lines
1.7 KiB
Python

import torch
import torch.nn.functional as F
class ZeroPadDictCollate():
    """
    Collate function for a Dataset whose items are dictionaries.

    Produces one dict with the same keys as the first item in the batch:

    - Tensors with at least one dimension are zero-padded at the end of every
      dimension up to the largest size of that dimension in the batch, then
      stacked along a new leading batch dimension.
    - 0-dim (scalar) tensors are stacked directly.
    - Any other value is gathered into a plain Python list.
    """

    def collate_tensors(self, batch, key):
        """
        Zero-pad every `elem[key]` in `batch` to a common shape and stack them.

        Padding is only ever appended at the end of each dimension, never
        prepended.

        Args:
            batch: list of dicts; each `elem[key]` must be a torch.Tensor.
            key: the dict key to collate.

        Returns:
            A tensor of shape (len(batch), *largest_dims).

        Raises:
            ValueError: if the tensors under `key` do not all share one rank.
        """
        tensors = [elem[key] for elem in batch]
        ranks = {t.dim() for t in tensors}
        if len(ranks) > 1:
            # Without this check zip() below would silently truncate to the
            # shortest shape and padding would target the wrong dimensions.
            raise ValueError(
                f"Cannot collate key '{key}': tensors have mismatched ranks {sorted(ranks)}"
            )
        largest_dims = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]

        padded = []
        for t in tensors:
            # F.pad expects (left, right) pairs starting from the LAST dimension,
            # so walk the dimensions in reverse; only the right side is padded.
            padding = []
            for size, target in zip(reversed(t.shape), reversed(largest_dims)):
                padding.extend((0, target - size))
            # Skip the pad call when the tensor already has the target shape;
            # torch.stack copies either way, so the result is identical.
            padded.append(F.pad(t, tuple(padding)) if any(padding) else t)
        return torch.stack(padded, dim=0)

    def collate_into_list(self, batch, key):
        """Return `[elem[key] for elem in batch]` for values that are not tensors."""
        return [elem[key] for elem in batch]

    def __call__(self, batch):
        """
        Collate a list of dicts into a single dict.

        The set of keys is taken from the first item; every other item is
        expected to provide the same keys. An empty batch yields an empty dict.
        """
        if not batch:
            return {}
        first_dict = batch[0]
        collated = {}
        for key, value in first_dict.items():
            if not isinstance(value, torch.Tensor):
                collated[key] = self.collate_into_list(batch, key)
            elif value.dim() > 0:
                collated[key] = self.collate_tensors(batch, key)
            else:
                # Scalar tensors have nothing to pad; stack them into a 1-D tensor.
                collated[key] = torch.stack([b[key] for b in batch])
        return collated