diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index e65a1978..7d0584cc 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -32,6 +32,8 @@ def create_injector(opt_inject, env): return ImagePatchInjector(opt_inject, env) elif type == 'concatenate': return ConcatenateInjector(opt_inject, env) + elif type == 'margin_removal': + return MarginRemoval(opt_inject, env) else: raise NotImplementedError @@ -203,4 +205,15 @@ class ConcatenateInjector(Injector): def forward(self, state): input = [state[i] for i in self.input] - return {self.opt['out']: torch.cat(input, dim=self.dim)} \ No newline at end of file + return {self.opt['out']: torch.cat(input, dim=self.dim)} + + +# Removes margins from an image. +class MarginRemoval(Injector): + def __init__(self, opt, env): + super(MarginRemoval, self).__init__(opt, env) + self.margin = opt['margin'] + + def forward(self, state): + input = state[self.input] + return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} \ No newline at end of file