diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 63ea9325..7a09cfc1 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -189,26 +189,28 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_ def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin): - # Combine everything and feed it into the flow network at once for better efficiency. - batch_sz = input_list.shape[0] - flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)] - flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)] - flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0)) - flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0)) - sexts = [] - for i in range(total): - flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)] - flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)] - mid = input_list[:,i+1] - sext = torch.stack([input_list[:,i], mid, input_list[:,i+2], - resampler(input_list[:,i], flow_backward), - mid, - resampler(input_list[:,i+2], flow_forward)], dim=1) - # Apply margin - b, f, c, h, w = sext.shape - sext = sext.view(b, 3*6, h, w) # f*c = 6*3 - sext = sext[:, :, margin:-margin, margin:-margin] - sexts.append(sext) + with autocast(enabled=False): + input_list = input_list.float() + # Combine everything and feed it into the flow network at once for better efficiency. + batch_sz = input_list.shape[0] + flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)] + flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)] + flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0)) + flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0)) + sexts = [] + for i in range(total): + flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)] + flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)] + mid = input_list[:,i+1] + sext = torch.stack([input_list[:,i], mid, input_list[:,i+2], + resampler(input_list[:,i], flow_backward), + mid, + resampler(input_list[:,i+2], flow_forward)], dim=1) + # Apply margin + b, f, c, h, w = sext.shape + sext = sext.view(b, 3*6, h, w) # f*c = 6*3 + sext = sext[:, :, margin:-margin, margin:-margin] + sexts.append(sext) return torch.cat(sexts, dim=0)