diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 8effb6bf..1238eb88 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -101,5 +101,5 @@ class GreyInjector(Injector): def forward(self, state): mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True) - mean = torch.repeat(mean, (-1, 3, -1, -1)) + mean = mean.repeat((1, 3, 1, 1)) return {self.opt['out']: mean}