Add InterpolateInjector
This commit is contained in:
parent
d90c96e55e
commit
365813bde3
|
@ -15,6 +15,8 @@ def create_injector(opt_inject, env):
|
|||
return AddNoiseInjector(opt_inject, env)
|
||||
elif type == 'greyscale':
|
||||
return GreyInjector(opt_inject, env)
|
||||
elif type == 'interpolate':
|
||||
return InterpolateInjector(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -101,5 +103,15 @@ class GreyInjector(Injector):
|
|||
|
||||
def forward(self, state):
|
||||
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
||||
mean = mean.repeat((1, 3, 1, 1))
|
||||
mean = mean.repeat(1, 3, 1, 1)
|
||||
return {self.opt['out']: mean}
|
||||
|
||||
|
||||
class InterpolateInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(InterpolateInjector, self).__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
|
||||
mode=self.opt['mode'])
|
||||
return {self.opt['out']: scaled}
|
|
@ -86,7 +86,6 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
|||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||
|
||||
def forward(self, net, state):
|
||||
with torch.no_grad():
|
||||
logits_real = self.netF_real(state[self.opt['real']])
|
||||
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||
return self.criterion(logits_fake, logits_real)
|
||||
|
|
Loading…
Reference in New Issue
Block a user