From 365813bde356fe2fd2f02edd17860bda0bbea026 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 3 Sep 2020 11:32:47 -0600 Subject: [PATCH] Add InterpolateInjector --- codes/models/steps/injectors.py | 14 +++++++++++++- codes/models/steps/losses.py | 3 +-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 1238eb88..c0b3fb9c 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -15,6 +15,8 @@ def create_injector(opt_inject, env): return AddNoiseInjector(opt_inject, env) elif type == 'greyscale': return GreyInjector(opt_inject, env) + elif type == 'interpolate': + return InterpolateInjector(opt_inject, env) else: raise NotImplementedError @@ -101,5 +103,15 @@ class GreyInjector(Injector): def forward(self, state): mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True) - mean = mean.repeat((1, 3, 1, 1)) + mean = mean.repeat(1, 3, 1, 1) return {self.opt['out']: mean} + + +class InterpolateInjector(Injector): + def __init__(self, opt, env): + super(InterpolateInjector, self).__init__(opt, env) + + def forward(self, state): + scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'], + mode=self.opt['mode']) + return {self.opt['out']: scaled} \ No newline at end of file diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 8b648be2..77c53581 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -86,8 +86,7 @@ class InterpretedFeatureLoss(ConfigurableLoss): self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen) def forward(self, net, state): - with torch.no_grad(): - logits_real = self.netF_real(state[self.opt['real']]) + logits_real = self.netF_real(state[self.opt['real']]) logits_fake = self.netF_gen(state[self.opt['fake']]) return self.criterion(logits_fake, logits_real)