Allow tecogan losses to compute at 32px
This commit is contained in:
parent
629b968901
commit
f857eb00a8
|
@ -97,9 +97,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
else:
|
else:
|
||||||
input = extract_inputs_index(inputs, i)
|
input = extract_inputs_index(inputs, i)
|
||||||
with torch.no_grad() and autocast(enabled=False):
|
with torch.no_grad() and autocast(enabled=False):
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float()
|
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
|
||||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
if self.scale >= 4:
|
||||||
|
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
|
||||||
|
else:
|
||||||
|
flow_input = input[self.input_lq_index]
|
||||||
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
|
||||||
|
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
|
||||||
|
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
|
||||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||||
input[self.recurrent_index] = recurrent_input
|
input[self.recurrent_index] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
|
@ -122,9 +128,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
input = extract_inputs_index(inputs, i)
|
input = extract_inputs_index(inputs, i)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float()
|
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
|
||||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
if self.scale >= 4:
|
||||||
|
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
|
||||||
|
else:
|
||||||
|
flow_input = input[self.input_lq_index]
|
||||||
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
|
||||||
|
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
|
||||||
|
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
|
||||||
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
recurrent_input = self.resample(recurrent_input.float(), flowfield)
|
||||||
input[self.recurrent_index] = recurrent_input
|
input[self.recurrent_index] = recurrent_input
|
||||||
if self.env['step'] % 50 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user