Add FFT injector

This commit is contained in:
James Betker 2020-10-21 22:22:00 -06:00
parent 1ef559d7ca
commit d1175f0de1

View File

@ -38,6 +38,8 @@ def create_injector(opt_inject, env):
return ForEachInjector(opt_inject, env)
elif type == 'constant':
return ConstantInjector(opt_inject, env)
elif type == 'fft':
return ImageFftInjector(opt_inject, env)
else:
raise NotImplementedError
@ -262,3 +264,35 @@ class ConstantInjector(Injector):
else:
raise NotImplementedError
return { self.opt['out']: out }
class ImageFftInjector(Injector):
def __init__(self, opt, env):
super(ImageFftInjector, self).__init__(opt, env)
self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward.
self.eps = 1e-100
def forward(self, state):
if self.forward:
fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True)
b, f, h, w, c = fftim.shape
fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w)
# Normalize across spatial dimension
mean = torch.mean(fftim, dim=(0,1))
fftim = fftim - mean
std = torch.std(fftim, dim=(0,1))
fftim = (fftim + self.eps) / std
return {self.output: fftim,
'%s_std' % (self.output,): std,
'%s_mean' % (self.output,): mean}
else:
b, f, h, w = state[self.input].shape
# First, de-normalize the FFT.
mean = state['%s_mean' % (self.input,)]
std = state['%s_std' % (self.input,)]
fftim = state[self.input] * std + mean - self.eps
# Second, recover the FFT dimensions from the given filters.
fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2)
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
return {self.output: im}