49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
import torch
|
|
|
|
# "long" and "short" denote longer and shorter samples
|
|
class PixelShuffle1D(torch.nn.Module):
|
|
"""
|
|
1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
|
|
Upscales sample length, downscales channel length
|
|
"short" is input, "long" is output
|
|
"""
|
|
def __init__(self, upscale_factor):
|
|
super(PixelShuffle1D, self).__init__()
|
|
self.upscale_factor = upscale_factor
|
|
|
|
def forward(self, x):
|
|
batch_size = x.shape[0]
|
|
short_channel_len = x.shape[1]
|
|
short_width = x.shape[2]
|
|
|
|
long_channel_len = short_channel_len // self.upscale_factor
|
|
long_width = self.upscale_factor * short_width
|
|
|
|
x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width])
|
|
x = x.permute(0, 2, 3, 1).contiguous()
|
|
x = x.view(batch_size, long_channel_len, long_width)
|
|
|
|
return x
|
|
|
|
class PixelUnshuffle1D(torch.nn.Module):
|
|
"""
|
|
Inverse of 1D pixel shuffler
|
|
Upscales channel length, downscales sample length
|
|
"long" is input, "short" is output
|
|
"""
|
|
def __init__(self, downscale_factor):
|
|
super(PixelUnshuffle1D, self).__init__()
|
|
self.downscale_factor = downscale_factor
|
|
|
|
def forward(self, x):
|
|
batch_size = x.shape[0]
|
|
long_channel_len = x.shape[1]
|
|
long_width = x.shape[2]
|
|
|
|
short_channel_len = long_channel_len * self.downscale_factor
|
|
short_width = long_width // self.downscale_factor
|
|
|
|
x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor])
|
|
x = x.permute(0, 3, 1, 2).contiguous()
|
|
x = x.view([batch_size, short_channel_len, short_width])
|
|
return x |