From 0d30d18a3d9cdaa2313633ba8a17cca0f3128626 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 9 Oct 2020 20:35:56 -0600 Subject: [PATCH] Add MarginRemoval injector --- codes/models/steps/injectors.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index e65a1978..7d0584cc 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -32,6 +32,8 @@ def create_injector(opt_inject, env): return ImagePatchInjector(opt_inject, env) elif type == 'concatenate': return ConcatenateInjector(opt_inject, env) + elif type == 'margin_removal': + return MarginRemoval(opt_inject, env) else: raise NotImplementedError @@ -203,4 +205,15 @@ class ConcatenateInjector(Injector): def forward(self, state): input = [state[i] for i in self.input] - return {self.opt['out']: torch.cat(input, dim=self.dim)} \ No newline at end of file + return {self.opt['out']: torch.cat(input, dim=self.dim)} + + +# Removes margins from an image. +class MarginRemoval(Injector): + def __init__(self, opt, env): + super(MarginRemoval, self).__init__(opt, env) + self.margin = opt['margin'] + + def forward(self, state): + input = state[self.input] + return {self.opt['out']: input[:, :, self.margin:-self.margin, self.margin:-self.margin]} \ No newline at end of file