forked from mrq/DL-Art-School
Add InterpolateInjector
This commit is contained in:
parent
d90c96e55e
commit
365813bde3
|
@ -15,6 +15,8 @@ def create_injector(opt_inject, env):
|
||||||
return AddNoiseInjector(opt_inject, env)
|
return AddNoiseInjector(opt_inject, env)
|
||||||
elif type == 'greyscale':
|
elif type == 'greyscale':
|
||||||
return GreyInjector(opt_inject, env)
|
return GreyInjector(opt_inject, env)
|
||||||
|
elif type == 'interpolate':
|
||||||
|
return InterpolateInjector(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -101,5 +103,15 @@ class GreyInjector(Injector):
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
|
||||||
mean = mean.repeat((1, 3, 1, 1))
|
mean = mean.repeat(1, 3, 1, 1)
|
||||||
return {self.opt['out']: mean}
|
return {self.opt['out']: mean}
|
||||||
|
|
||||||
|
|
||||||
|
class InterpolateInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(InterpolateInjector, self).__init__(opt, env)
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
|
||||||
|
mode=self.opt['mode'])
|
||||||
|
return {self.opt['out']: scaled}
|
|
@ -86,8 +86,7 @@ class InterpretedFeatureLoss(ConfigurableLoss):
|
||||||
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
self.netF_gen = torch.nn.parallel.DataParallel(self.netF_gen)
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
with torch.no_grad():
|
logits_real = self.netF_real(state[self.opt['real']])
|
||||||
logits_real = self.netF_real(state[self.opt['real']])
|
|
||||||
logits_fake = self.netF_gen(state[self.opt['fake']])
|
logits_fake = self.netF_gen(state[self.opt['fake']])
|
||||||
return self.criterion(logits_fake, logits_real)
|
return self.criterion(logits_fake, logits_real)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user