diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 60cb94af..832bb25b 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -246,14 +246,12 @@ class ConstantInjector(Injector): def __init__(self, opt, env): super(ConstantInjector, self).__init__(opt, env) self.constant_type = opt['constant_type'] - self.dim = opt['dim'] self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. def forward(self, state): - bs = state[self.like].shape[0] - dev = state[self.like].device + like = state[self.like] if self.constant_type == 'zeroes': - out = torch.zeros((bs,) + tuple(self.dim), device=dev) + out = torch.zeros_like(like) else: raise NotImplementedError return { self.opt['out']: out }