forked from mrq/DL-Art-School
add pixel shuffling for 1d cases
This commit is contained in:
parent
c42c53e75a
commit
6655f7845a
|
@ -277,3 +277,29 @@ class ConditioningLatentDistributionDivergenceInjector(Injector):
|
||||||
mean_loss = F.mse_loss(sp_means, tr_means)
|
mean_loss = F.mse_loss(sp_means, tr_means)
|
||||||
var_loss = F.mse_loss(sp_vars, tr_vars)
|
var_loss = F.mse_loss(sp_vars, tr_vars)
|
||||||
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
return {self.output: mean_loss, self.var_loss_key: var_loss}
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_shuffle_1d(x, upscale_factor):
|
||||||
|
batch_size, channels, steps = x.size()
|
||||||
|
channels //= upscale_factor
|
||||||
|
input_view = x.contiguous().view(batch_size, channels, upscale_factor, steps)
|
||||||
|
shuffle_out = input_view.permute(0, 1, 3, 2).contiguous()
|
||||||
|
return shuffle_out.view(batch_size, channels, steps * upscale_factor)
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_unshuffle_1d(x, downscale):
|
||||||
|
b, c, s = x.size()
|
||||||
|
x = x.view(b, c, s//downscale, downscale)
|
||||||
|
x = x.permute(0,1,3,2).contiguous()
|
||||||
|
x = x.view(b, c*downscale, s//downscale)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AudioUnshuffleInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super().__init__(opt, env)
|
||||||
|
self.compression = opt['compression']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
inp = state[self.input]
|
||||||
|
return {self.output: pixel_unshuffle_1d(inp, self.compression)}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user