Allow initial recurrent input to be specified (optionally)

This commit is contained in:
James Betker 2020-10-12 17:36:43 -06:00
parent 597b6e92d6
commit 05377973bf

View File

@ -67,6 +67,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.resample = Resample2d()
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
@ -89,6 +90,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
for i in range(f):
if first_step:
input = extract_inputs_index(first_inputs, i)
if self.hq_recurrent:
recurrent_input = input[self.recurrent_index]
else:
recurrent_input = torch.zeros_like(input[self.recurrent_index])
first_step = False
else: