add pixel shuffling for 1d cases

This commit is contained in:
James Betker 2022-05-04 08:03:09 -06:00
parent c42c53e75a
commit 6655f7845a

View File

@ -277,3 +277,29 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
mean_loss = F.mse_loss(sp_means, tr_means) mean_loss = F.mse_loss(sp_means, tr_means)
var_loss = F.mse_loss(sp_vars, tr_vars) var_loss = F.mse_loss(sp_vars, tr_vars)
return {self.output: mean_loss, self.var_loss_key: var_loss} return {self.output: mean_loss, self.var_loss_key: var_loss}
def pixel_shuffle_1d(x, upscale_factor):
batch_size, channels, steps = x.size()
channels //= upscale_factor
input_view = x.contiguous().view(batch_size, channels, upscale_factor, steps)
shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
return shuffle_out.view(batch_size, channels, steps * upscale_factor)
def pixel_unshuffle_1d(x, downscale):
b, c, s = x.size()
x = x.view(b, c, s//downscale, downscale)
x = x.permute(0,1,3,2).contiguous()
x = x.view(b, c*downscale, s//downscale)
return x
class AudioUnshuffleInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
self.compression = opt['compression']
def forward(self, state):
inp = state[self.input]
return {self.output: pixel_unshuffle_1d(inp, self.compression)}