diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 52cac3c9..63ea9325 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -201,9 +201,9 @@ def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_ flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)] mid = input_list[:,i+1] sext = torch.stack([input_list[:,i], mid, input_list[:,i+2], - resampler(mid, flow_backward), + resampler(input_list[:,i], flow_backward), mid, - resampler(mid, flow_forward)], dim=1) + resampler(input_list[:,i+2], flow_forward)], dim=1) # Apply margin b, f, c, h, w = sext.shape sext = sext.view(b, 3*6, h, w) # f*c = 6*3