From 68e9db12b51aa31d371c2d5dfbf6857fdf50b247 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 2 Dec 2021 21:04:49 -0700 Subject: [PATCH] Add interleaving and direct injectors --- codes/trainer/injectors/base_injectors.py | 34 +++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py index f13e7608..10a4d08c 100644 --- a/codes/trainer/injectors/base_injectors.py +++ b/codes/trainer/injectors/base_injectors.py @@ -13,6 +13,40 @@ from utils.util import opt_get from utils.weight_scheduler import get_scheduler_for_opt +# Transfers the state in the input key to the output key +class DirectInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + + def forward(self, state): + return {self.output: state[self.input]} + + +# Allows multiple injectors to be used on sequential steps. +class StepInterleaveInjector(Injector): + def __init__(self, opt, env): + super().__init__(opt, env) + for inj in opt['subinjectors'].keys(): + o = opt.copy() + o['subinjectors'] = opt['subtype'] + o['in'] = '_in' + o['out'] = '_out' + self.injector = create_injector(o, self.env) + self.aslist = opt['aslist'] if 'aslist' in opt.keys() else False + + def forward(self, state): + injs = [] + st = state.copy() + inputs = state[self.opt['in']] + for i in range(inputs.shape[1]): + st['_in'] = inputs[:, i] + injs.append(self.injector(st)['_out']) + if self.aslist: + return {self.output: injs} + else: + return {self.output: torch.stack(injs, dim=1)} + + class PadInjector(Injector): def __init__(self, opt, env): super().__init__(opt, env)