Add FFT injector
This commit is contained in:
parent
1ef559d7ca
commit
d1175f0de1
|
@ -38,6 +38,8 @@ def create_injector(opt_inject, env):
|
|||
return ForEachInjector(opt_inject, env)
|
||||
elif type == 'constant':
|
||||
return ConstantInjector(opt_inject, env)
|
||||
elif type == 'fft':
|
||||
return ImageFftInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -262,3 +264,35 @@ class ConstantInjector(Injector):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
|
||||
|
||||
class ImageFftInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ImageFftInjector, self).__init__(opt, env)
|
||||
self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward.
|
||||
self.eps = 1e-100
|
||||
|
||||
def forward(self, state):
|
||||
if self.forward:
|
||||
fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True)
|
||||
b, f, h, w, c = fftim.shape
|
||||
fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w)
|
||||
# Normalize across spatial dimension
|
||||
mean = torch.mean(fftim, dim=(0,1))
|
||||
fftim = fftim - mean
|
||||
std = torch.std(fftim, dim=(0,1))
|
||||
fftim = (fftim + self.eps) / std
|
||||
return {self.output: fftim,
|
||||
'%s_std' % (self.output,): std,
|
||||
'%s_mean' % (self.output,): mean}
|
||||
else:
|
||||
b, f, h, w = state[self.input].shape
|
||||
# First, de-normalize the FFT.
|
||||
mean = state['%s_mean' % (self.input,)]
|
||||
std = state['%s_std' % (self.input,)]
|
||||
fftim = state[self.input] * std + mean - self.eps
|
||||
# Second, recover the FFT dimensions from the given filters.
|
||||
fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2)
|
||||
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
|
||||
return {self.output: im}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user