Fix constant injector - wasn't working in test

This commit is contained in:
James Betker 2020-10-12 10:36:30 -06:00
parent e7cf337dba
commit d7d7590f3e

View File

@ -246,14 +246,12 @@ class ConstantInjector(Injector):
def __init__(self, opt, env): def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env) super(ConstantInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type'] self.constant_type = opt['constant_type']
self.dim = opt['dim']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use. self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
def forward(self, state): def forward(self, state):
bs = state[self.like].shape[0] like = state[self.like]
dev = state[self.like].device
if self.constant_type == 'zeroes': if self.constant_type == 'zeroes':
out = torch.zeros((bs,) + tuple(self.dim), device=dev) out = torch.zeros_like(like)
else: else:
raise NotImplementedError raise NotImplementedError
return { self.opt['out']: out } return { self.opt['out']: out }