Fix greyscale injector

This commit is contained in:
James Betker 2020-09-02 10:29:40 -06:00
parent 8b52d46847
commit d90c96e55e

View File

@ -101,5 +101,5 @@ class GreyInjector(Injector):
def forward(self, state):
mean = torch.mean(state[self.opt['in']], dim=1, keepdim=True)
mean = torch.repeat(mean, (-1, 3, -1, -1))
mean = mean.repeat((1, 3, 1, 1))
return {self.opt['out']: mean}