Fix up translational equivariance loss so it's ready for prime time
This commit is contained in:
parent
896b4f5be2
commit
66d4512029
|
@ -145,10 +145,16 @@ class GreyInjector(Injector):
|
|||
class InterpolateInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(InterpolateInjector, self).__init__(opt, env)
|
||||
if 'scale_factor' in opt.keys():
|
||||
self.scale_factor = opt['scale_factor']
|
||||
self.size = None
|
||||
else:
|
||||
self.scale_factor = None
|
||||
self.size = (opt['size'], opt['size'])
|
||||
|
||||
def forward(self, state):
|
||||
scaled = torch.nn.functional.interpolate(state[self.opt['in']], scale_factor=self.opt['scale_factor'],
|
||||
mode=self.opt['mode'])
|
||||
size=self.opt['size'], mode=self.opt['mode'])
|
||||
return {self.opt['out']: scaled}
|
||||
|
||||
|
||||
|
@ -171,11 +177,11 @@ class ImagePatchInjector(Injector):
|
|||
def forward(self, state):
|
||||
im = state[self.opt['in']]
|
||||
if self.env['training']:
|
||||
return { self.opt['out']: im[:, :self.patch_size, :self.patch_size],
|
||||
'%s_top_left' % (self.opt['out'],): im[:, :self.patch_size, :self.patch_size],
|
||||
'%s_top_right' % (self.opt['out'],): im[:, :self.patch_size, -self.patch_size:],
|
||||
'%s_bottom_left' % (self.opt['out'],): im[:, -self.patch_size:, :self.patch_size],
|
||||
'%s_bottom_right' % (self.opt['out'],): im[:, -self.patch_size:, -self.patch_size:] }
|
||||
return { self.opt['out']: im[:, :3, :self.patch_size, :self.patch_size],
|
||||
'%s_top_left' % (self.opt['out'],): im[:, :, :self.patch_size, :self.patch_size],
|
||||
'%s_top_right' % (self.opt['out'],): im[:, :, :self.patch_size, -self.patch_size:],
|
||||
'%s_bottom_left' % (self.opt['out'],): im[:, :, -self.patch_size:, :self.patch_size],
|
||||
'%s_bottom_right' % (self.opt['out'],): im[:, :, -self.patch_size:, -self.patch_size:] }
|
||||
else:
|
||||
return { self.opt['out']: im,
|
||||
'%s_top_left' % (self.opt['out'],): im,
|
||||
|
|
|
@ -258,6 +258,7 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
self.gen_output_to_use = opt['generator_output_index'] if 'generator_output_index' in opt.keys() else None
|
||||
self.patch_size = opt['patch_size']
|
||||
self.overlap = opt['overlap'] # For maximum overlap, can be calculated as 2*patch_size-image_size
|
||||
self.detach_fake = opt['detach_fake']
|
||||
assert(self.patch_size > self.overlap)
|
||||
|
||||
def forward(self, net, state):
|
||||
|
@ -271,15 +272,19 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
("bottom_right", 0, self.overlap, 0, self.overlap)])
|
||||
trans_name, hl, hh, wl, wh = translation
|
||||
# Change the "fake" input name that we are translating to one that specifies the random translation.
|
||||
self.opt['fake'][self.gen_input_for_alteration] = "%s_%s" % (self.opt['fake'], trans_name)
|
||||
input = extract_params_from_state(self.opt['fake'], state)
|
||||
with torch.no_grad():
|
||||
fake = self.opt['fake'].copy()
|
||||
fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
|
||||
input = extract_params_from_state(fake, state)
|
||||
if self.detach_fake:
|
||||
with torch.no_grad():
|
||||
trans_output = net(*input)
|
||||
else:
|
||||
trans_output = net(*input)
|
||||
fake_shared_output = trans_output[:, hl:hh, wl:wh][self.gen_output_to_use]
|
||||
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
||||
|
||||
# The "real" input is assumed to always come from the top left tile.
|
||||
gen_output = state[self.opt['real']]
|
||||
real_shared_output = gen_output[:, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap][self.gen_output_to_use]
|
||||
real_shared_output = gen_output[:, :, border_sz:border_sz+self.overlap, border_sz:border_sz+self.overlap]
|
||||
|
||||
return self.criterion(fake_shared_output, real_shared_output)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user