Fix constant injector - wasn't working in test
This commit is contained in:
parent
e7cf337dba
commit
d7d7590f3e
|
@ -246,14 +246,12 @@ class ConstantInjector(Injector):
|
|||
def __init__(self, opt, env):
|
||||
super(ConstantInjector, self).__init__(opt, env)
|
||||
self.constant_type = opt['constant_type']
|
||||
self.dim = opt['dim']
|
||||
self.like = opt['like'] # This injector uses this tensor to determine what batch size and device to use.
|
||||
|
||||
def forward(self, state):
|
||||
bs = state[self.like].shape[0]
|
||||
dev = state[self.like].device
|
||||
like = state[self.like]
|
||||
if self.constant_type == 'zeroes':
|
||||
out = torch.zeros((bs,) + tuple(self.dim), device=dev)
|
||||
out = torch.zeros_like(like)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return { self.opt['out']: out }
|
||||
|
|
Loading…
Reference in New Issue
Block a user