import torch

# "long" and "short" denote longer and shorter samples (along the width axis).


class PixelShuffle1D(torch.nn.Module):
    """1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf

    Upscales sample length, downscales channel length.
    "short" is input, "long" is output.

    Rearranges a tensor of shape ``(batch, channels, width)`` into
    ``(batch, channels // r, width * r)`` where ``r = upscale_factor``.
    Output element ``[b, c, w * r + i]`` is input element
    ``[b, i * (channels // r) + c, w]``. ``PixelUnshuffle1D`` with the same
    factor is the exact inverse.

    Args:
        upscale_factor: factor ``r`` (>= 1); must divide the input's
            channel count.

    Raises:
        ValueError: if ``upscale_factor`` < 1 (at construction), or, in
            ``forward``, if the input is not 3-D or its channel count is not
            divisible by ``upscale_factor``.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        if upscale_factor < 1:
            raise ValueError(
                f"upscale_factor must be >= 1, got {upscale_factor}"
            )
        self.upscale_factor = upscale_factor

    def extra_repr(self):
        return f"upscale_factor={self.upscale_factor}"

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError(
                "PixelShuffle1D expects a 3-D input (batch, channels, width), "
                f"got shape {tuple(x.shape)}"
            )
        factor = self.upscale_factor
        batch_size, short_channel_len, short_width = x.shape
        # Checked explicitly: with an empty batch, view() would otherwise
        # succeed and silently return a truncated channel count.
        if short_channel_len % factor != 0:
            raise ValueError(
                f"channel count {short_channel_len} is not divisible by "
                f"upscale_factor {factor}"
            )
        long_channel_len = short_channel_len // factor
        long_width = factor * short_width

        # Split channels as (factor, long_channel_len): index i * C_long + c.
        x = x.reshape(batch_size, factor, long_channel_len, short_width)
        # Move the factor axis after width so it interleaves with samples.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x.view(batch_size, long_channel_len, long_width)


class PixelUnshuffle1D(torch.nn.Module):
    """Inverse of the 1D pixel shuffler.

    Upscales channel length, downscales sample length.
    "long" is input, "short" is output.

    Rearranges a tensor of shape ``(batch, channels, width)`` into
    ``(batch, channels * r, width // r)`` where ``r = downscale_factor``.
    Output element ``[b, i * channels + c, w]`` is input element
    ``[b, c, w * r + i]``.

    Args:
        downscale_factor: factor ``r`` (>= 1); must divide the input's width.

    Raises:
        ValueError: if ``downscale_factor`` < 1 (at construction), or, in
            ``forward``, if the input is not 3-D or its width is not
            divisible by ``downscale_factor``.
    """

    def __init__(self, downscale_factor):
        super().__init__()
        if downscale_factor < 1:
            raise ValueError(
                f"downscale_factor must be >= 1, got {downscale_factor}"
            )
        self.downscale_factor = downscale_factor

    def extra_repr(self):
        return f"downscale_factor={self.downscale_factor}"

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError(
                "PixelUnshuffle1D expects a 3-D input (batch, channels, width), "
                f"got shape {tuple(x.shape)}"
            )
        factor = self.downscale_factor
        batch_size, long_channel_len, long_width = x.shape
        # Checked explicitly: with an empty batch, view() would otherwise
        # succeed and silently return a truncated width.
        if long_width % factor != 0:
            raise ValueError(
                f"width {long_width} is not divisible by "
                f"downscale_factor {factor}"
            )
        short_channel_len = long_channel_len * factor
        short_width = long_width // factor

        # Split width as (short_width, factor): sample index w * factor + i.
        x = x.reshape(batch_size, long_channel_len, short_width, factor)
        # Move the factor axis in front of channels: channel i * C_long + c.
        x = x.permute(0, 3, 1, 2).contiguous()
        return x.view(batch_size, short_channel_len, short_width)