Allow initial recurrent input to be specified (optionally)
This commit is contained in:
parent
597b6e92d6
commit
05377973bf
|
@ -67,6 +67,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
self.resample = Resample2d()
|
||||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
|
@ -89,7 +90,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
for i in range(f):
|
||||
if first_step:
|
||||
input = extract_inputs_index(first_inputs, i)
|
||||
recurrent_input = torch.zeros_like(input[self.recurrent_index])
|
||||
if self.hq_recurrent:
|
||||
recurrent_input = input[self.recurrent_index]
|
||||
else:
|
||||
recurrent_input = torch.zeros_like(input[self.recurrent_index])
|
||||
first_step = False
|
||||
else:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
|
|
Loading…
Reference in New Issue
Block a user