Fix fast_forward teco loss bug

This commit is contained in:
James Betker 2020-10-28 17:49:54 -06:00
parent 25b007a0f5
commit 1655b9e242

View File

@ -201,9 +201,9 @@ def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_
flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
mid = input_list[:,i+1]
sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
resampler(mid, flow_backward),
resampler(input_list[:,i], flow_backward),
mid,
resampler(mid, flow_forward)], dim=1)
resampler(input_list[:,i+2], flow_forward)], dim=1)
# Apply margin
b, f, c, h, w = sext.shape
sext = sext.view(b, 3*6, h, w) # f*c = 6*3