Fix constant injector - wasn't working in test
This commit is contained in:
parent
e7cf337dba
commit
d7d7590f3e
|
@ -246,14 +246,12 @@ class ConstantInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(ConstantInjector, self).__init__(opt, env)
|
super(ConstantInjector, self).__init__(opt, env)
|
||||||
self.constant_type = opt['constant_type']
|
self.constant_type = opt['constant_type']
|
||||||
self.dim = opt['dim']
|
|
||||||
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
bs = state[self.like].shape[0]
|
like = state[self.like]
|
||||||
dev = state[self.like].device
|
|
||||||
if self.constant_type == 'zeroes':
|
if self.constant_type == 'zeroes':
|
||||||
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
out = torch.zeros_like(like)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return { self.opt['out']: out }
|
return { self.opt['out']: out }
|
||||||
|
|
Loading…
Reference in New Issue
Block a user