Add MarginRemoval injector

This commit is contained in:
James Betker 2020-10-09 20:35:56 -06:00
parent 0011d445c8
commit 0d30d18a3d

View File

@ -32,6 +32,8 @@ def create_injector(opt_inject, env):
return ImagePatchInjector(opt_inject, env)
elif type == 'concatenate':
return ConcatenateInjector(opt_inject, env)
elif type == 'margin_removal':
return MarginRemoval(opt_inject, env)
else:
raise NotImplementedError
@ -203,4 +205,15 @@ class ConcatenateInjector(Injector):
def forward(self, state):
input = [state[i] for i in self.input]
return {self.opt['out']: torch.cat(input, dim=self.dim)}
return {self.opt['out']: torch.cat(input, dim=self.dim)}
# Removes margins from an image.
class MarginRemoval(Injector):
def __init__(self, opt, env):
super(MarginRemoval, self).__init__(opt, env)
self.margin = opt['margin']
def forward(self, state):
input = state[self.input]
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}