Add MarginRemoval injector
This commit is contained in:
parent
0011d445c8
commit
0d30d18a3d
|
@ -32,6 +32,8 @@ def create_injector(opt_inject, env):
|
||||||
return ImagePatchInjector(opt_inject, env)
|
return ImagePatchInjector(opt_inject, env)
|
||||||
elif type == 'concatenate':
|
elif type == 'concatenate':
|
||||||
return ConcatenateInjector(opt_inject, env)
|
return ConcatenateInjector(opt_inject, env)
|
||||||
|
elif type == 'margin_removal':
|
||||||
|
return MarginRemoval(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -203,4 +205,15 @@ class ConcatenateInjector(Injector):
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
input = [state[i] for i in self.input]
|
input = [state[i] for i in self.input]
|
||||||
return {self.opt['out']: torch.cat(input, dim=self.dim)}
|
return {self.opt['out']: torch.cat(input, dim=self.dim)}
|
||||||
|
|
||||||
|
|
||||||
|
# Removes margins from an image.
|
||||||
|
class MarginRemoval(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(MarginRemoval, self).__init__(opt, env)
|
||||||
|
self.margin = opt['margin']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
input = state[self.input]
|
||||||
|
return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]}
|
Loading…
Reference in New Issue
Block a user