Fix up translational equivariance loss so it's ready for prime time

James Betker 2020-09-30 12:01:00 -06:00
parent 896b4f5be2
commit 66d4512029
2 changed files with 22 additions and 11 deletions

@@ -145,10 +145,16 @@ class GreyInjector(Injector):
 class InterpolateInjector(Injector):
     def __init__(self, opt, env):
         super(InterpolateInjector, self).__init__(opt, env)
+        if 'scale_factor' in opt.keys():
+            self.scale_factor = opt['scale_factor']
+            self.size = None
+        else:
+            self.scale_factor = None
+            self.size = (opt['size'], opt['size'])

     def forward(self, state):
-        scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
-                                                 mode=self.opt['mode'])
+        scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.scale_factor,
+                                                 size=self.size, mode=self.opt['mode'])
         return {self.opt['out']: scaled}
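
For context on this hunk, a minimal usage sketch, with hypothetical opt dicts and tensor shapes. torch.nn.functional.interpolate requires exactly one of size and scale_factor to be non-None, which is why __init__ now stores None for whichever key is absent:

    import torch

    # Hypothetical configs: one resizes by a relative factor, one to a fixed
    # square size; the injector derives (self.scale_factor, self.size) from these.
    opt_by_scale = {'in': 'hq', 'out': 'hq_down', 'scale_factor': 0.5, 'mode': 'bilinear'}
    opt_by_size = {'in': 'hq', 'out': 'hq_small', 'size': 64, 'mode': 'bilinear'}

    x = torch.randn(4, 3, 128, 128)
    down = torch.nn.functional.interpolate(x, scale_factor=0.5, size=None, mode='bilinear')
    fixed = torch.nn.functional.interpolate(x, scale_factor=None, size=(64, 64), mode='bilinear')
    assert down.shape == fixed.shape == (4, 3, 64, 64)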
@@ -171,11 +177,11 @@ class ImagePatchInjector(Injector):
     def forward(self, state):
         im = state[self.opt['in']]
         if self.env['training']:
-            return { self.opt['out']: im[:, :self.patch_size, :self.patch_size],
-                     '%s_top_left' % (self.opt['out'],): im[:, :self.patch_size, :self.patch_size],
-                     '%s_top_right' % (self.opt['out'],): im[:, :self.patch_size, -self.patch_size:],
-                     '%s_bottom_left' % (self.opt['out'],): im[:, -self.patch_size:, :self.patch_size],
-                     '%s_bottom_right' % (self.opt['out'],): im[:, -self.patch_size:, -self.patch_size:] }
+            return { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
+                     '%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
+                     '%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
+                     '%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
+                     '%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] }
         else:
             return { self.opt['out']: im,
                      '%s_top_left' % (self.opt['out'],): im,
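
This hunk is easiest to verify on a concrete tensor. A minimal sketch with hypothetical shapes: im follows the NCHW layout, so the old three-index slice cropped the channel and height dimensions instead of height and width:

    import torch

    im = torch.randn(2, 3, 64, 64)  # NCHW; shapes hypothetical
    patch_size = 48

    # Old indexing: the second index lands on the channel dim (size 3), so this
    # "patch" silently keeps the full width and never crops it.
    wrong = im[:, :patch_size, :patch_size]
    assert wrong.shape == (2, 3, 48, 64)

    # Fixed indexing: keep all channels, crop height and width.
    top_left = im[:, :, :patch_size, :patch_size]
    bottom_right = im[:, :, -patch_size:, -patch_size:]
    assert top_left.shape == bottom_right.shape == (2, 3, 48, 48)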

@@ -258,6 +258,7 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
         self.patch_size = opt['patch_size']
         self.overlap = opt['overlap']  # For maximum overlap, can be calculated as 2*patch_size-image_size
+        self.detach_fake = opt['detach_fake']
         assert(self.patch_size > self.overlap)

     def forward(self, net, state):
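
The overlap comment in this hunk is worth a worked example (numbers hypothetical): two patch_size crops taken from opposite corners of an image_size image share a strip of width 2*patch_size - image_size, which is also why the assert demands patch_size > overlap:

    image_size, patch_size = 96, 64
    overlap = 2 * patch_size - image_size  # maximum usable overlap: 32
    # The top-left patch covers columns [0, 64); the bottom-right covers [32, 96).
    # Their intersection [32, 64) is exactly `overlap` pixels wide.
    assert overlap == 32 and patch_size > overlap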
@@ -271,15 +272,19 @@ class TranslationInvarianceLoss(ConfigurableLoss):
                                      ("bottom_right", 0, self.overlap, 0, self.overlap)])
         trans_name, hl, hh, wl, wh = translation
         # Change the "fake" input name that we are translating to one that specifies the random translation.
-        self.opt['fake'][self.gen_input_for_alteration] = "%s_%s" % (self.opt['fake'], trans_name)
-        input = extract_params_from_state(self.opt['fake'], state)
-        with torch.no_grad():
-            trans_output = net(*input)
-        fake_shared_output = trans_output[:, hl:hh, wl:wh][self.gen_output_to_use]
+        fake = self.opt['fake'].copy()
+        fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
+        input = extract_params_from_state(fake, state)
+        if self.detach_fake:
+            with torch.no_grad():
+                trans_output = net(*input)
+        else:
+            trans_output = net(*input)
+        fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
         # The "real" input is assumed to always come from the top left tile.
         gen_output = state[self.opt['real']]
-        real_shared_output = gen_output[:, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap][self.gen_output_to_use]
+        real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
         return self.criterion(fake_shared_output, real_shared_output)
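
Finally, a minimal sketch of the indexing fix in the last hunk, assuming a generator that returns a tuple of tensors (which is what generator_output_index implies). The tuple must be unpacked with gen_output_to_use before slicing, and the crop needs all four NCHW indices:

    import torch

    trans_output = (torch.randn(2, 3, 64, 64), torch.randn(2, 3, 32, 32))  # hypothetical net outputs
    gen_output_to_use = 0
    hl, hh, wl, wh = 0, 16, 0, 16

    # The old order failed twice over: indexing a tuple with [:, hl:hh, wl:wh]
    # raises TypeError, and even on a tensor it would have sliced channels, not height.
    fake_shared = trans_output[gen_output_to_use][:, :, hl:hh, wl:wh]
    assert fake_shared.shape == (2, 3, 16, 16)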