Add InterpolateInjector

This commit is contained in:
James Betker 2020-09-03 11:32:47 -06:00
parent d90c96e55e
commit 365813bde3
2 changed files with 14 additions and 3 deletions

View File

@ -15,6 +15,8 @@ def create_injector(opt_inject, env):
return AddNoiseInjector(opt_inject, env)
elif type == 'greyscale':
return GreyInjector(opt_inject, env)
elif type == 'interpolate':
return InterpolateInjector(opt_inject, env)
else:
raise NotImplementedError
@ -101,5 +103,15 @@ class GreyInjector(Injector):
def forward(self, state):
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
mean = mean.repeat((1, 3, 1, 1))
mean = mean.repeat(1, 3, 1, 1)
return {self.opt['out']: mean}
class InterpolateInjector(Injector):
def __init__(self, opt, env):
super(InterpolateInjector, self).__init__(opt, env)
def forward(self, state):
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
mode=self.opt['mode'])
return {self.opt['out']: scaled}

View File

@ -86,8 +86,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
def forward(self, net, state):
with torch.no_grad():
logits_real = self.netF_real(state[self.opt['real']])
logits_real = self.netF_real(state[self.opt['real']])
logits_fake = self.netF_gen(state[self.opt['fake']])
return self.criterion(logits_fake, logits_real)