Allow initial recurrent input to be specified (optionally)
This commit is contained in:
parent
597b6e92d6
commit
05377973bf
|
@ -67,6 +67,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
self.resample = Resample2d()
|
self.resample = Resample2d()
|
||||||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||||
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||||
|
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
|
@ -89,7 +90,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
for i in range(f):
|
for i in range(f):
|
||||||
if first_step:
|
if first_step:
|
||||||
input = extract_inputs_index(first_inputs, i)
|
input = extract_inputs_index(first_inputs, i)
|
||||||
recurrent_input = torch.zeros_like(input[self.recurrent_index])
|
if self.hq_recurrent:
|
||||||
|
recurrent_input = input[self.recurrent_index]
|
||||||
|
else:
|
||||||
|
recurrent_input = torch.zeros_like(input[self.recurrent_index])
|
||||||
first_step = False
|
first_step = False
|
||||||
else:
|
else:
|
||||||
input = extract_inputs_index(inputs, i)
|
input = extract_inputs_index(inputs, i)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user