From 66d45120295012467513128ad554196940234748 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 30 Sep 2020 12:01:00 -0600
Subject: [PATCH] Fix up translational equivariance loss so it's ready for
 prime time

---
 codes/models/steps/injectors.py | 18 ++++++++++++------
 codes/models/steps/losses.py    | 15 ++++++++++-----
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py
index 74ad5a72..421d44cc 100644
--- a/codes/models/steps/injectors.py
+++ b/codes/models/steps/injectors.py
@@ -145,10 +145,16 @@ class GreyInjector(Injector):
 class InterpolateInjector(Injector):
     def __init__(self, opt, env):
         super(InterpolateInjector, self).__init__(opt, env)
+        if 'scale_factor' in opt.keys():
+            self.scale_factor = opt['scale_factor']
+            self.size = None
+        else:
+            self.scale_factor = None
+            self.size = (opt['size'], opt['size'])
 
     def forward(self, state):
-        scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
-                                                 mode=self.opt['mode'])
+        scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.scale_factor,
+                                                 size=self.size, mode=self.opt['mode'])
         return {self.opt['out']: scaled}
 
 
@@ -171,11 +177,11 @@ class ImagePatchInjector(Injector):
     def forward(self, state):
         im = state[self.opt['in']]
         if self.env['training']:
-            return { self.opt['out']: im[:, :self.patch_size, :self.patch_size],
-                     '%s_top_left' % (self.opt['out'],): im[:, :self.patch_size, :self.patch_size],
-                     '%s_top_right' % (self.opt['out'],): im[:, :self.patch_size, -self.patch_size:],
-                     '%s_bottom_left' % (self.opt['out'],): im[:, -self.patch_size:, :self.patch_size],
-                     '%s_bottom_right' % (self.opt['out'],): im[:, -self.patch_size:, -self.patch_size:] }
+            return { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
+                     '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
+                     '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
+                     '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
+                     '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] }
         else:
             return { self.opt['out']: im,
                      '%s_top_left' % (self.opt['out'],): im,
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index afecd82b..7e6f7c30 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -258,6 +258,7 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
         self.patch_size = opt['patch_size']
         self.overlap = opt['overlap']  # For maximum overlap, can be calculated as 2*patch_size-image_size
+        self.detach_fake = opt['detach_fake']
         assert(self.patch_size > self.overlap)
 
     def forward(self, net, state):
@@ -271,15 +272,19 @@ class TranslationInvarianceLoss(ConfigurableLoss):
                                         ("bottom_right", 0, self.overlap, 0, self.overlap)])
         trans_name, hl, hh, wl, wh = translation
         # Change the "fake" input name that we are translating to one that specifies the random translation.
-        self.opt['fake'][self.gen_input_for_alteration] = "%s_%s" % (self.opt['fake'], trans_name)
-        input = extract_params_from_state(self.opt['fake'], state)
-        with torch.no_grad():
+        fake = self.opt['fake'].copy()
+        fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
+        input = extract_params_from_state(fake, state)
+        if self.detach_fake:
+            with torch.no_grad():
+                trans_output = net(*input)
+        else:
             trans_output = net(*input)
-        fake_shared_output = trans_output[:, hl:hh, wl:wh][self.gen_output_to_use]
+        fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
 
         # The "real" input is assumed to always come from the top left tile.
         gen_output = state[self.opt['real']]
-        real_shared_output = gen_output[:, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap][self.gen_output_to_use]
+        real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
 
         return self.criterion(fake_shared_output, real_shared_output)
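
For context on the InterpolateInjector change: torch.nn.functional.interpolate requires exactly one of size/scale_factor to be non-None, which is why __init__ now resolves the two config keys into mutually exclusive attributes. A minimal, self-contained sketch of that resolution; the opt values here are hypothetical, only the 'scale_factor'/'size' handling mirrors the injector:

import torch
import torch.nn.functional as F

# Hypothetical injector options: exactly one of 'scale_factor' or 'size'.
opt = {'in': 'hq', 'out': 'hq_small', 'size': 64, 'mode': 'bilinear'}

# Same either/or resolution as the patched __init__.
scale_factor = opt['scale_factor'] if 'scale_factor' in opt.keys() else None
size = (opt['size'], opt['size']) if 'size' in opt.keys() else None

state = {'hq': torch.randn(1, 3, 128, 128)}
# interpolate() accepts exactly one of size/scale_factor; the other stays None.
scaled = F.interpolate(state[opt['in']], scale_factor=scale_factor,
                       size=size, mode=opt['mode'])
print(scaled.shape)  # torch.Size([1, 3, 64, 64])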
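
The re-ordered indexing in TranslationInvarianceLoss matters because the generator can return a list/tuple of outputs: the configured output index must be applied first (trans_output[self.gen_output_to_use]) and the spatial crop second, over dims 2 and 3 of an NCHW tensor. The crop compares the region the translated tile shares with the top-left tile. A standalone sketch of that geometry at maximum overlap, with raw image pixels standing in for generator outputs; image_size and the derived border_sz are illustrative assumptions, not taken from the repo:

import torch

image_size = 96
patch_size = 64
overlap = 2 * patch_size - image_size   # maximum overlap, per the comment in the patch
border_sz = patch_size - overlap        # assumed offset of the shared region in the top-left tile

im = torch.randn(1, 3, image_size, image_size)

# Tiles as ImagePatchInjector emits them during training.
top_left = im[:, :, :patch_size, :patch_size]
bottom_right = im[:, :, -patch_size:, -patch_size:]

# For the "bottom_right" translation, (hl, hh, wl, wh) = (0, overlap, 0, overlap):
fake_shared = bottom_right[:, :, 0:overlap, 0:overlap]
real_shared = top_left[:, :, border_sz:border_sz + overlap, border_sz:border_sz + overlap]

# Both crops address the same pixels of the source image, so an
# equivariant (here: identity) generator drives the criterion to zero.
assert torch.equal(fake_shared, real_shared)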