Fix constant injector - wasn't working in test

This commit is contained in:
James Betker 2020-10-12 10:36:30 -06:00
parent e7cf337dba
commit d7d7590f3e

View File

@ -246,14 +246,12 @@ class ConstantInjector(Injector):
def __init__(self, opt, env):
super(ConstantInjector, self).__init__(opt, env)
self.constant_type = opt['constant_type']
self.dim = opt['dim']
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
def forward(self, state):
bs = state[self.like].shape[0]
dev = state[self.like].device
like = state[self.like]
if self.constant_type == 'zeroes':
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
out = torch.zeros_like(like)
else:
raise NotImplementedError
return { self.opt['out']: out }