Add interleaving and direct injectors
This commit is contained in:
parent
04454ee63a
commit
68e9db12b5
|
@ -13,6 +13,40 @@ from utils.util import opt_get
|
|||
from utils.weight_scheduler import get_scheduler_for_opt
|
||||
|
||||
|
||||
# Transfers the state in the input key to the output key
|
||||
class DirectInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
|
||||
def forward(self, state):
|
||||
return {self.output: state[self.input]}
|
||||
|
||||
|
||||
# Allows multiple injectors to be used on sequential steps.
|
||||
class StepInterleaveInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
for inj in opt['subinjectors'].keys():
|
||||
o = opt.copy()
|
||||
o['subinjectors'] = opt['subtype']
|
||||
o['in'] = '_in'
|
||||
o['out'] = '_out'
|
||||
self.injector = create_injector(o, self.env)
|
||||
self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False
|
||||
|
||||
def forward(self, state):
|
||||
injs = []
|
||||
st = state.copy()
|
||||
inputs = state[self.opt['in']]
|
||||
for i in range(inputs.shape[1]):
|
||||
st['_in'] = inputs[:, i]
|
||||
injs.append(self.injector(st)['_out'])
|
||||
if self.aslist:
|
||||
return {self.output: injs}
|
||||
else:
|
||||
return {self.output: torch.stack(injs, dim=1)}
|
||||
|
||||
|
||||
class PadInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super().__init__(opt, env)
|
||||
|
|
Loading…
Reference in New Issue
Block a user