Add MarginRemoval injector
This commit is contained in:
parent
0011d445c8
commit
0d30d18a3d
|
@ -32,6 +32,8 @@ def create_injector(opt_inject, env):
|
|||
return ImagePatchInjector(opt_inject, env)
|
||||
elif type == 'concatenate':
|
||||
return ConcatenateInjector(opt_inject, env)
|
||||
elif type == 'margin_removal':
|
||||
return MarginRemoval(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -203,4 +205,15 @@ class ConcatenateInjector(Injector):
|
|||
|
||||
def forward(self, state):
|
||||
input = [state[i] for i in self.input]
|
||||
return {self.opt['out']: torch.cat(input, dim=self.dim)}
|
||||
return {self.opt['out']: torch.cat(input, dim=self.dim)}
|
||||
|
||||
|
||||
# Removes margins from an image.
|
||||
class MarginRemoval(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(MarginRemoval, self).__init__(opt, env)
|
||||
self.margin = opt['margin']
|
||||
|
||||
def forward(self, state):
|
||||
input = state[self.input]
|
||||
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
Loading…
Reference in New Issue
Block a user