forked from mrq/DL-Art-School
add pixel shuffling for 1d cases
This commit is contained in:
parent
c42c53e75a
commit
6655f7845a
|
@ -277,3 +277,29 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
|||
mean_loss = F.mse_loss(sp_means, tr_means)
|
||||
var_loss = F.mse_loss(sp_vars, tr_vars)
|
||||
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
||||
|
||||
|
||||
def pixel_shuffle_1d(x, upscale_factor):
|
||||
batch_size, channels, steps = x.size()
|
||||
channels //= upscale_factor
|
||||
input_view = x.contiguous().view(batch_size, channels, upscale_factor, steps)
|
||||
shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
|
||||
return shuffle_out.view(batch_size, channels, steps * upscale_factor)
|
||||
|
||||
|
||||
def pixel_unshuffle_1d(x, downscale):
|
||||
b, c, s = x.size()
|
||||
x = x.view(b, c, s//downscale, downscale)
|
||||
x = x.permute(0,1,3,2).contiguous()
|
||||
x = x.view(b, c*downscale, s//downscale)
|
||||
return x
|
||||
|
||||
|
||||
class AudioUnshuffleInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
self.compression = opt['compression']
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
return {self.output: pixel_unshuffle_1d(inp, self.compression)}
|
||||
|
|
Loading…
Reference in New Issue
Block a user