From d1175f0de12f6f15f2aa4180556227c6cc594951 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 21 Oct 2020 22:22:00 -0600 Subject: [PATCH] Add FFT injector --- codes/models/steps/injectors.py | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index dc7abff7..de78aff3 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -38,6 +38,8 @@ def create_injector(opt_inject, env): return ForEachInjector(opt_inject, env) elif type == 'constant': return ConstantInjector(opt_inject, env) + elif type == 'fft': + return ImageFftInjector(opt_inject, env) else: raise NotImplementedError @@ -262,3 +264,35 @@ class ConstantInjector(Injector): else: raise NotImplementedError return { self.opt['out']: out } + + +class ImageFftInjector(Injector): + def __init__(self, opt, env): + super(ImageFftInjector, self).__init__(opt, env) + self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward. + self.eps = 1e-100 + + def forward(self, state): + if self.forward: + fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True) + b, f, h, w, c = fftim.shape + fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w) + # Normalize across spatial dimension + mean = torch.mean(fftim, dim=(0,1)) + fftim = fftim - mean + std = torch.std(fftim, dim=(0,1)) + fftim = (fftim + self.eps) / std + return {self.output: fftim, + '%s_std' % (self.output,): std, + '%s_mean' % (self.output,): mean} + else: + b, f, h, w = state[self.input].shape + # First, de-normalize the FFT. + mean = state['%s_mean' % (self.input,)] + std = state['%s_std' % (self.input,)] + fftim = state[self.input] * std + mean - self.eps + # Second, recover the FFT dimensions from the given filters. + fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2) + im = torch.irfft(fftim, signal_ndim=2, normalized=True) + return {self.output: im} +