# Patch excerpt for codes/trainer/injectors/audio_injectors.py
# (hunk @@ -277,3 +277,29 @@). The definitions below are appended after
# ConditioningLatentDistributionDivergenceInjector, whose closing statements
# appear in the patch only as unchanged context:
#     mean_loss = F.mse_loss(sp_means, tr_means)
#     var_loss = F.mse_loss(sp_vars, tr_vars)
#     return {self.output: mean_loss, self.var_loss_key: var_loss}


def pixel_shuffle_1d(x, upscale_factor):
    """Fold groups of channels into the last axis of a 3-D tensor.

    Args:
        x: Tensor whose ``size()`` unpacks to (batch, channels, steps).
        upscale_factor: Number of consecutive channels interleaved into each
            output step; ``channels`` is floor-divided by this value.

    Returns:
        Tensor of shape (batch, channels // upscale_factor,
        steps * upscale_factor) with
        ``out[b, c, s * upscale_factor + k] == x[b, c * upscale_factor + k, s]``.
    """
    n_batch, n_channels, n_steps = x.size()
    out_channels = n_channels // upscale_factor
    # Split the channel axis into (out_channels, upscale_factor) groups.
    grouped = x.contiguous().view(n_batch, out_channels, upscale_factor, n_steps)
    # Swap the group axis behind the step axis so that flattening the last two
    # axes interleaves the grouped channels along time.
    interleaved = grouped.transpose(2, 3).contiguous()
    return interleaved.view(n_batch, out_channels, n_steps * upscale_factor)


def pixel_unshuffle_1d(x, downscale):
    """Fold blocks of the last axis of a 3-D tensor into the channel axis.

    This applies the reverse index mapping of ``pixel_shuffle_1d``.

    Args:
        x: Tensor whose ``size()`` unpacks to (batch, channels, steps).
        downscale: Number of consecutive steps moved into the channel axis;
            ``steps`` is floor-divided by this value.

    Returns:
        Tensor of shape (batch, channels * downscale, steps // downscale) with
        ``out[b, c * downscale + k, t] == x[b, c, t * downscale + k]``.
    """
    n_batch, n_channels, n_steps = x.size()
    out_steps = n_steps // downscale
    # Split the step axis into (out_steps, downscale) frames.
    framed = x.view(n_batch, n_channels, out_steps, downscale)
    # Bring the within-frame offset next to the channel axis, then merge the
    # two into a single, wider channel axis.
    stacked = framed.transpose(2, 3).contiguous()
    return stacked.view(n_batch, n_channels * downscale, out_steps)


class AudioUnshuffleInjector(Injector):
    """Injector that runs ``pixel_unshuffle_1d`` on one entry of the state.

    The downscale factor is read from ``opt['compression']``.
    """

    def __init__(self, opt, env):
        super().__init__(opt, env)
        self.compression = opt['compression']

    def forward(self, state):
        """Return ``{self.output: unshuffled}`` for the tensor at ``state[self.input]``.

        NOTE(review): ``self.input`` / ``self.output`` are not set in this
        class; presumably the ``Injector`` base initializes them -- confirm.
        """
        source = state[self.input]
        unshuffled = pixel_unshuffle_1d(source, self.compression)
        return {self.output: unshuffled}