From d7d7590f3e5a8bf7ba5ff317443d3a1397cfa39b Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 12 Oct 2020 10:36:30 -0600 Subject: [PATCH] Fix constant injector - wasn't working in test --- codes/models/steps/injectors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 60cb94af..832bb25b 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -246,14 +246,12 @@ class ConstantInjector(Injector): def __init__(self, opt, env): super(ConstantInjector, self).__init__(opt, env) self.constant_type = opt['constant_type'] - self.dim = opt['dim'] self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. def forward(self, state): - bs = state[self.like].shape[0] - dev = state[self.like].device + like = state[self.like] if self.constant_type == 'zeroes': - out = torch.zeros((bs,) + tuple(self.dim), device=dev) + out = torch.zeros_like(like) else: raise NotImplementedError return { self.opt['out']: out }