Add interleaving and direct injectors

This commit is contained in:
James Betker 2021-12-02 21:04:49 -07:00
parent 04454ee63a
commit 68e9db12b5

View File

@ -13,6 +13,40 @@ from utils.util import opt_get
from utils.weight_scheduler import get_scheduler_for_opt
# Transfers the state in the input key to the output key
class DirectInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
def forward(self, state):
return {self.output: state[self.input]}
# Allows multiple injectors to be used on sequential steps.
class StepInterleaveInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
for inj in opt['subinjectors'].keys():
o = opt.copy()
o['subinjectors'] = opt['subtype']
o['in'] = '_in'
o['out'] = '_out'
self.injector = create_injector(o, self.env)
self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False
def forward(self, state):
injs = []
st = state.copy()
inputs = state[self.opt['in']]
for i in range(inputs.shape[1]):
st['_in'] = inputs[:, i]
injs.append(self.injector(st)['_out'])
if self.aslist:
return {self.output: injs}
else:
return {self.output: torch.stack(injs, dim=1)}
class PadInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)