forked from mrq/DL-Art-School
Add FFT injector
This commit is contained in:
parent
1ef559d7ca
commit
d1175f0de1
|
@ -38,6 +38,8 @@ def create_injector(opt_inject, env):
|
||||||
return ForEachInjector(opt_inject, env)
|
return ForEachInjector(opt_inject, env)
|
||||||
elif type == 'constant':
|
elif type == 'constant':
|
||||||
return ConstantInjector(opt_inject, env)
|
return ConstantInjector(opt_inject, env)
|
||||||
|
elif type == 'fft':
|
||||||
|
return ImageFftInjector(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -262,3 +264,35 @@ class ConstantInjector(Injector):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return { self.opt['out']: out }
|
return { self.opt['out']: out }
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFftInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ImageFftInjector, self).__init__(opt, env)
|
||||||
|
self.is_forward = opt['forward'] # Whether to compute a forward FFT or backward.
|
||||||
|
self.eps = 1e-100
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
if self.forward:
|
||||||
|
fftim = torch.rfft(state[self.input], signal_ndim=2, normalized=True)
|
||||||
|
b, f, h, w, c = fftim.shape
|
||||||
|
fftim = fftim.permute(0,1,4,2,3).reshape(b,-1,h,w)
|
||||||
|
# Normalize across spatial dimension
|
||||||
|
mean = torch.mean(fftim, dim=(0,1))
|
||||||
|
fftim = fftim - mean
|
||||||
|
std = torch.std(fftim, dim=(0,1))
|
||||||
|
fftim = (fftim + self.eps) / std
|
||||||
|
return {self.output: fftim,
|
||||||
|
'%s_std' % (self.output,): std,
|
||||||
|
'%s_mean' % (self.output,): mean}
|
||||||
|
else:
|
||||||
|
b, f, h, w = state[self.input].shape
|
||||||
|
# First, de-normalize the FFT.
|
||||||
|
mean = state['%s_mean' % (self.input,)]
|
||||||
|
std = state['%s_std' % (self.input,)]
|
||||||
|
fftim = state[self.input] * std + mean - self.eps
|
||||||
|
# Second, recover the FFT dimensions from the given filters.
|
||||||
|
fftim = fftim.reshape(b, f // 2, 2, h, w).permute(0,1,3,4,2)
|
||||||
|
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
|
||||||
|
return {self.output: im}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user