Fix tecogan_losses fp16
This commit is contained in:
parent
3791f95ad0
commit
b316078a15
|
@ -189,26 +189,28 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_
|
||||||
|
|
||||||
|
|
||||||
def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
|
def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
|
||||||
# Combine everything and feed it into the flow network at once for better efficiency.
|
with autocast(enabled=False):
|
||||||
batch_sz = input_list.shape[0]
|
input_list = input_list.float()
|
||||||
flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
|
# Combine everything and feed it into the flow network at once for better efficiency.
|
||||||
flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
|
batch_sz = input_list.shape[0]
|
||||||
flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
|
flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
|
||||||
flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
|
flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
|
||||||
sexts = []
|
flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
|
||||||
for i in range(total):
|
flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
|
||||||
flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
|
sexts = []
|
||||||
flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
|
for i in range(total):
|
||||||
mid = input_list[:,i+1]
|
flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
|
||||||
sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
|
flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
|
||||||
resampler(input_list[:,i], flow_backward),
|
mid = input_list[:,i+1]
|
||||||
mid,
|
sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
|
||||||
resampler(input_list[:,i+2], flow_forward)], dim=1)
|
resampler(input_list[:,i], flow_backward),
|
||||||
# Apply margin
|
mid,
|
||||||
b, f, c, h, w = sext.shape
|
resampler(input_list[:,i+2], flow_forward)], dim=1)
|
||||||
sext = sext.view(b, 3*6, h, w) # f*c = 6*3
|
# Apply margin
|
||||||
sext = sext[:, :, margin:-margin, margin:-margin]
|
b, f, c, h, w = sext.shape
|
||||||
sexts.append(sext)
|
sext = sext.view(b, 3*6, h, w) # f*c = 6*3
|
||||||
|
sext = sext[:, :, margin:-margin, margin:-margin]
|
||||||
|
sexts.append(sext)
|
||||||
return torch.cat(sexts, dim=0)
|
return torch.cat(sexts, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user