Finish integration with autocast
Note: autocast is broken when also using checkpoint(). Overcome this by modifying torch's checkpoint() function in place to also use autocast.
This commit is contained in:
parent d7ee14f721
commit 15e00e9014
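The checkpoint() workaround described in the commit message amounts to monkey-patching torch.utils.checkpoint so that the wrapped function runs under autocast both on the initial forward pass and when it is recomputed during backward. A minimal sketch of that idea in Python (the name patched_checkpoint and the module-level patching shown here are illustrative assumptions, not code taken from this commit):

import torch.utils.checkpoint
from torch.cuda.amp import autocast

# Keep a reference to the stock implementation before replacing it.
_original_checkpoint = torch.utils.checkpoint.checkpoint

def patched_checkpoint(fn, *args, **kwargs):
    # Wrap the checkpointed callable so that both the initial forward pass and
    # the recomputation triggered during backward execute inside autocast.
    def run_with_autocast(*inner_args):
        with autocast():
            return fn(*inner_args)
    return _original_checkpoint(run_with_autocast, *args, **kwargs)

# Patch in place, as the commit message describes.
torch.utils.checkpoint.checkpoint = patched_checkpoint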
@@ -1,4 +1,6 @@
 import torch.nn
+from torch.cuda.amp import autocast
+
 from models.archs.SPSR_arch import ImageGradientNoPadding
 from utils.weight_scheduler import get_scheduler_for_opt
 from models.steps.losses import extract_params_from_state
|
@@ -65,11 +67,12 @@ class ImageGeneratorInjector(Injector):
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
-        if isinstance(self.input, list):
-            params = extract_params_from_state(self.input, state)
-            results = gen(*params)
-        else:
-            results = gen(state[self.input])
+        with autocast(enabled=self.env['opt']['fp16']):
+            if isinstance(self.input, list):
+                params = extract_params_from_state(self.input, state)
+                results = gen(*params)
+            else:
+                results = gen(state[self.input])
         new_state = {}
         if isinstance(self.output, list):
             # Only dereference tuples or lists, not tensors.
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from torch.cuda.amp import autocast
+
 from models.networks import define_F
 from models.loss import GANLoss
 import random
|
@@ -164,20 +166,21 @@ class GeneratorGanLoss(ConfigurableLoss):
                     nfake.append(fake[i])
             real = nreal
             fake = nfake
-        if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
-            pred_g_fake = netD(*fake)
-            loss = self.criterion(pred_g_fake, True)
-        elif self.opt['gan_type'] == 'ragan':
-            pred_d_real = netD(*real)
-            if self.detach_real:
-                pred_d_real = pred_d_real.detach()
-            pred_g_fake = netD(*fake)
-            d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
-            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
-                    d_fake_diff) / 2
-        else:
-            raise NotImplementedError
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
+                pred_g_fake = netD(*fake)
+                loss = self.criterion(pred_g_fake, True)
+            elif self.opt['gan_type'] == 'ragan':
+                pred_d_real = netD(*real)
+                if self.detach_real:
+                    pred_d_real = pred_d_real.detach()
+                pred_g_fake = netD(*fake)
+                d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True)
+                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+                loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                        d_fake_diff) / 2
+            else:
+                raise NotImplementedError
         if self.min_loss != 0:
             self.loss_rotating_buffer[self.rb_ptr] = loss.item()
             self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
@@ -219,8 +222,9 @@ class DiscriminatorGanLoss(ConfigurableLoss):
                     nfake.append(fake[i])
             real = nreal
             fake = nfake
-        d_real = net(*real)
-        d_fake = net(*fake)
+        with autocast(enabled=self.env['opt']['fp16']):
+            d_real = net(*real)
+            d_fake = net(*fake)
 
         if self.opt['gan_type'] in ['gan', 'pixgan']:
             self.metrics.append(("d_fake", torch.mean(d_fake)))
|
@@ -279,11 +283,13 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss):
                 altered.append(alteration(t))
             else:
                 altered.append(t)
-        if self.detach_fake:
-            with torch.no_grad():
-                upsampled_altered = net(*altered)
-        else:
-            upsampled_altered = net(*altered)
+
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    upsampled_altered = net(*altered)
+            else:
+                upsampled_altered = net(*altered)
 
         if self.gen_output_to_use is not None:
             upsampled_altered = upsampled_altered[self.gen_output_to_use]
|
@@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss):
         fake = self.opt['fake'].copy()
         fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name)
         input = extract_params_from_state(fake, state)
-        if self.detach_fake:
-            with torch.no_grad():
-                trans_output = net(*input)
-        else:
-            trans_output = net(*input)
+
+        with autocast(enabled=self.env['opt']['fp16']):
+            if self.detach_fake:
+                with torch.no_grad():
+                    trans_output = net(*input)
+            else:
+                trans_output = net(*input)
+
         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
         else:
@@ -375,7 +384,8 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
         input = extract_params_from_state(fake, state)
         for i in range(self.recursive_depth):
             input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest")
-            recurrent_gen_output = net(*input)[self.gen_output_to_use]
+            with autocast(enabled=self.env['opt']['fp16']):
+                recurrent_gen_output = net(*input)[self.gen_output_to_use]
 
         compare_real = gen_output
         compare_fake = recurrent_gen_output
@@ -3,6 +3,7 @@ import random
 
 import torch
 import torchvision
+from torch.cuda.amp import autocast
 
 from data.multiscale_dataset import build_multiscale_patch_index_map
 from models.steps.injectors import Injector
|
@@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector):
             ff_input = inputs.copy()
             ff_input[self.input_lq_index] = lq_input
             ff_input[self.recurrent_index] = recurrent_input
-            gen_out = gen(*ff_input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*ff_input)
+
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
@@ -25,6 +25,7 @@ class ConfigurableStep(Module):
         self.loss_accumulator = LossAccumulator()
         self.optimizers = None
+        self.scaler = GradScaler(enabled=self.opt['fp16'])
         self.grads_generated = False
 
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
|
@@ -126,21 +127,20 @@ class ConfigurableStep(Module):
         self.env['training'] = train
 
         # Inject in any extra dependencies.
-        with autocast(enabled=self.opt['fp16']):
-            for inj in self.injectors:
-                # Don't do injections tagged with eval unless we are not in train mode.
-                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                    continue
-                # Likewise, don't do injections tagged with train unless we are not in eval.
-                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                    continue
-                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-                        'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                    continue
-                injected = inj(local_state)
-                local_state.update(injected)
-                new_state.update(injected)
+        for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+                    'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)
 
         if train and len(self.losses) > 0:
             # Finally, compute the losses.
|
@@ -150,7 +150,6 @@ class ConfigurableStep(Module):
                 # be very disruptive to a generator.
                 if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
                     continue
-
                 l = loss(self.training_net, local_state)
                 total_loss += l * self.weights[loss_name]
                 # Record metrics.
@@ -167,9 +166,8 @@ class ConfigurableStep(Module):
             total_loss = total_loss / self.env['mega_batch_factor']
+
             # Get dem grads!
-            # Workaround for https://github.com/pytorch/pytorch/issues/37730
-            with autocast():
-                self.scaler.scale(total_loss).backward()
+            self.scaler.scale(total_loss).backward()
             self.grads_generated = True
 
             # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
             # we must release the gradients.
@@ -179,6 +177,9 @@ class ConfigurableStep(Module):
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
     # all self.optimizers.
     def do_step(self):
+        if not self.grads_generated:
+            return
+        self.grads_generated = False
         for opt in self.optimizers:
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
+
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
|
@@ -24,10 +26,10 @@ def create_teco_injector(opt, env):
         return FlowAdjustment(opt, env)
     return None
 
-def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
+def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16):
     triplet = input_list[:, index:index+3]
     # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
-    with torch.no_grad():
+    with torch.no_grad() and autocast(enabled=fp16):
         first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
         #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
         last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
|
@@ -99,14 +101,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    with autocast(enabled=self.env['opt']['fp16']):
+                        flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    # Resample does not work in FP16.
                     recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
                 input[self.recurrent_index] = recurrent_input
                 if self.env['step'] % 50 == 0:
                     self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
                     debug_index += 1
-            gen_out = gen(*input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*input)
 
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
|
@@ -121,14 +127,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
+                    with autocast(enabled=self.env['opt']['fp16']):
+                        flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                     recurrent_input = self.resample(recurrent_input.float(), flowfield.float())
                 input[self.recurrent_index
                 ] = recurrent_input
                 if self.env['step'] % 50 == 0:
                     self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
                     debug_index += 1
-            gen_out = gen(*input)
+
+            with autocast(enabled=self.env['opt']['fp16']):
+                gen_out = gen(*input)
+
             if isinstance(gen_out, torch.Tensor):
                 gen_out = [gen_out]
             for i, out_key in enumerate(self.output):
|
@@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
 
     def forward(self, _, state):
+        fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
|
@@ -200,10 +211,11 @@ class TecoGanLoss(ConfigurableLoss):
         lr = state[self.opt['lr_inputs']]
         l_total = 0
         for i in range(sequence_len - 2):
-            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            d_fake = net(fake_sext)
-            d_real = net(real_sext)
+            real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16)
+            with autocast(enabled=fp16):
+                d_fake = net(fake_sext)
+                d_real = net(real_sext)
             self.metrics.append(("d_fake", torch.mean(d_fake)))
             self.metrics.append(("d_real", torch.mean(d_real)))
|
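For context on how the new GradScaler field is meant to be used: the scaler added to ConfigurableStep and the scaler.scale(total_loss).backward() call above are two pieces of the standard torch.cuda.amp recipe, in which forward passes run under autocast and the optimizer is stepped through the scaler. A minimal, self-contained sketch of that recipe (the model, optimizer, data and fp16 flag below are placeholders for illustration, not objects from this repository):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

# Placeholder model, optimizer and data; the real project wires these up via its config.
fp16 = True
model = nn.Conv2d(3, 3, 3, padding=1).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()
scaler = GradScaler(enabled=fp16)

for _ in range(10):
    lq = torch.randn(4, 3, 32, 32, device='cuda')
    hq = torch.randn(4, 3, 32, 32, device='cuda')
    opt.zero_grad()
    with autocast(enabled=fp16):          # forward pass in mixed precision
        loss = criterion(model(lq), hq)
    scaler.scale(loss).backward()         # scale the loss, then backprop
    scaler.step(opt)                      # unscales gradients, then steps the optimizer
    scaler.update()                       # adjust the scale factor for the next iteration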