diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index de78aff3..7e6e2fd9 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -1,4 +1,6 @@ import torch.nn +from torch.cuda.amp import autocast + from models.archs.SPSR_arch import ImageGradientNoPadding from utils.weight_scheduler import get_scheduler_for_opt from models.steps.losses import extract_params_from_state @@ -65,11 +67,12 @@ class ImageGeneratorInjector(Injector): def forward(self, state): gen = self.env['generators'][self.opt['generator']] - if isinstance(self.input, list): - params = extract_params_from_state(self.input, state) - results = gen(*params) - else: - results = gen(state[self.input]) + with autocast(enabled=self.env['opt']['fp16']): + if isinstance(self.input, list): + params = extract_params_from_state(self.input, state) + results = gen(*params) + else: + results = gen(state[self.input]) new_state = {} if isinstance(self.output, list): # Only dereference tuples or lists, not tensors. diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py index 5cd8e174..4641ac05 100644 --- a/codes/models/steps/losses.py +++ b/codes/models/steps/losses.py @@ -1,5 +1,7 @@ import torch import torch.nn as nn +from torch.cuda.amp import autocast + from models.networks import define_F from models.loss import GANLoss import random @@ -164,20 +166,21 @@ class GeneratorGanLoss(ConfigurableLoss): nfake.append(fake[i]) real = nreal fake = nfake - if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']: - pred_g_fake = netD(*fake) - loss = self.criterion(pred_g_fake, True) - elif self.opt['gan_type'] == 'ragan': - pred_d_real = netD(*real) - if self.detach_real: - pred_d_real = pred_d_real.detach() - pred_g_fake = netD(*fake) - d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True) - self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) - loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + - d_fake_diff) / 2 - else: - raise NotImplementedError + with autocast(enabled=self.env['opt']['fp16']): + if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']: + pred_g_fake = netD(*fake) + loss = self.criterion(pred_g_fake, True) + elif self.opt['gan_type'] == 'ragan': + pred_d_real = netD(*real) + if self.detach_real: + pred_d_real = pred_d_real.detach() + pred_g_fake = netD(*fake) + d_fake_diff = self.criterion(pred_g_fake - torch.mean(pred_d_real), True) + self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) + loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) + + d_fake_diff) / 2 + else: + raise NotImplementedError if self.min_loss != 0: self.loss_rotating_buffer[self.rb_ptr] = loss.item() self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0] @@ -219,8 +222,9 @@ class DiscriminatorGanLoss(ConfigurableLoss): nfake.append(fake[i]) real = nreal fake = nfake - d_real = net(*real) - d_fake = net(*fake) + with autocast(enabled=self.env['opt']['fp16']): + d_real = net(*real) + d_fake = net(*fake) if self.opt['gan_type'] in ['gan', 'pixgan']: self.metrics.append(("d_fake", torch.mean(d_fake))) @@ -279,11 +283,13 @@ class GeometricSimilarityGeneratorLoss(ConfigurableLoss): altered.append(alteration(t)) else: altered.append(t) - if self.detach_fake: - with torch.no_grad(): + + with autocast(enabled=self.env['opt']['fp16']): + if self.detach_fake: + with torch.no_grad(): + upsampled_altered = net(*altered) + else: upsampled_altered = net(*altered) - else: - upsampled_altered = net(*altered) if self.gen_output_to_use is not None: upsampled_altered = upsampled_altered[self.gen_output_to_use] @@ -327,11 +333,14 @@ class TranslationInvarianceLoss(ConfigurableLoss): fake = self.opt['fake'].copy() fake[self.gen_input_for_alteration] = "%s_%s" % (fake[self.gen_input_for_alteration], trans_name) input = extract_params_from_state(fake, state) - if self.detach_fake: - with torch.no_grad(): + + with autocast(enabled=self.env['opt']['fp16']): + if self.detach_fake: + with torch.no_grad(): + trans_output = net(*input) + else: trans_output = net(*input) - else: - trans_output = net(*input) + if self.gen_output_to_use is not None: fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh] else: @@ -375,7 +384,8 @@ class RecursiveInvarianceLoss(ConfigurableLoss): input = extract_params_from_state(fake, state) for i in range(self.recursive_depth): input[self.gen_input_for_alteration] = torch.nn.functional.interpolate(recurrent_gen_output, scale_factor=self.downsample_factor, mode="nearest") - recurrent_gen_output = net(*input)[self.gen_output_to_use] + with autocast(enabled=self.env['opt']['fp16']): + recurrent_gen_output = net(*input)[self.gen_output_to_use] compare_real = gen_output compare_fake = recurrent_gen_output diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py index c0363778..f4d047f6 100644 --- a/codes/models/steps/progressive_zoom.py +++ b/codes/models/steps/progressive_zoom.py @@ -3,6 +3,7 @@ import random import torch import torchvision +from torch.cuda.amp import autocast from data.multiscale_dataset import build_multiscale_patch_index_map from models.steps.injectors import Injector @@ -52,7 +53,10 @@ class ProgressiveGeneratorInjector(Injector): ff_input = inputs.copy() ff_input[self.input_lq_index] = lq_input ff_input[self.recurrent_index] = recurrent_input - gen_out = gen(*ff_input) + + with autocast(enabled=self.env['opt']['fp16']): + gen_out = gen(*ff_input) + if isinstance(gen_out, torch.Tensor): gen_out = [gen_out] for i, out_key in enumerate(self.output): diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index fa7a9f45..a8e0bdb2 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -25,6 +25,7 @@ class ConfigurableStep(Module): self.loss_accumulator = LossAccumulator() self.optimizers = None self.scaler = GradScaler(enabled=self.opt['fp16']) + self.grads_generated = False self.injectors = [] if 'injectors' in self.step_opt.keys(): @@ -126,21 +127,20 @@ class ConfigurableStep(Module): self.env['training'] = train # Inject in any extra dependencies. - with autocast(enabled=self.opt['fp16']): - for inj in self.injectors: - # Don't do injections tagged with eval unless we are not in train mode. - if train and 'eval' in inj.opt.keys() and inj.opt['eval']: - continue - # Likewise, don't do injections tagged with train unless we are not in eval. - if not train and 'train' in inj.opt.keys() and inj.opt['train']: - continue - # Don't do injections tagged with 'after' or 'before' when we are out of spec. - if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ - 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']: - continue - injected = inj(local_state) - local_state.update(injected) - new_state.update(injected) + for inj in self.injectors: + # Don't do injections tagged with eval unless we are not in train mode. + if train and 'eval' in inj.opt.keys() and inj.opt['eval']: + continue + # Likewise, don't do injections tagged with train unless we are not in eval. + if not train and 'train' in inj.opt.keys() and inj.opt['train']: + continue + # Don't do injections tagged with 'after' or 'before' when we are out of spec. + if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \ + 'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']: + continue + injected = inj(local_state) + local_state.update(injected) + new_state.update(injected) if train and len(self.losses) > 0: # Finally, compute the losses. @@ -150,7 +150,6 @@ class ConfigurableStep(Module): # be very disruptive to a generator. if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']: continue - l = loss(self.training_net, local_state) total_loss += l * self.weights[loss_name] # Record metrics. @@ -167,9 +166,8 @@ class ConfigurableStep(Module): total_loss = total_loss / self.env['mega_batch_factor'] # Get dem grads! - # Workaround for https://github.com/pytorch/pytorch/issues/37730 - with autocast(): - self.scaler.scale(total_loss).backward() + self.scaler.scale(total_loss).backward() + self.grads_generated = True # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # we must release the gradients. @@ -179,6 +177,9 @@ class ConfigurableStep(Module): # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps() # all self.optimizers. def do_step(self): + if not self.grads_generated: + return + self.grads_generated = False for opt in self.optimizers: # Optimizers can be opted out in the early stages of training. after = opt._config['after'] if 'after' in opt._config.keys() else 0 diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 2a9ab802..03175a1d 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -1,3 +1,5 @@ +from torch.cuda.amp import autocast + from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name from models.flownet2.networks.resample2d_package.resample2d import Resample2d from models.steps.injectors import Injector @@ -24,10 +26,10 @@ def create_teco_injector(opt, env): return FlowAdjustment(opt, env) return None -def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin): +def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin, fp16): triplet = input_list[:, index:index+3] # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. - with torch.no_grad(): + with torch.no_grad() and autocast(enabled=fp16): first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float()) #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float()) @@ -99,14 +101,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector): with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') + with autocast(enabled=self.env['opt']['fp16']): + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') # Resample does not work in FP16. recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) debug_index += 1 - gen_out = gen(*input) + + with autocast(enabled=self.env['opt']['fp16']): + gen_out = gen(*input) + if isinstance(gen_out, torch.Tensor): gen_out = [gen_out] for i, out_key in enumerate(self.output): @@ -121,14 +127,18 @@ class RecurrentImageGeneratorSequenceInjector(Injector): with torch.no_grad(): reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) - flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') + with autocast(enabled=self.env['opt']['fp16']): + flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') recurrent_input = self.resample(recurrent_input.float(), flowfield.float()) input[self.recurrent_index ] = recurrent_input if self.env['step'] % 50 == 0: self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index) debug_index += 1 - gen_out = gen(*input) + + with autocast(enabled=self.env['opt']['fp16']): + gen_out = gen(*input) + if isinstance(gen_out, torch.Tensor): gen_out = [gen_out] for i, out_key in enumerate(self.output): @@ -192,6 +202,7 @@ class TecoGanLoss(ConfigurableLoss): self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors. def forward(self, _, state): + fp16 = self.env['opt']['fp16'] net = self.env['discriminators'][self.opt['discriminator']] flow_gen = self.env['generators'][self.image_flow_generator] real = state[self.opt['real']] @@ -200,10 +211,11 @@ class TecoGanLoss(ConfigurableLoss): lr = state[self.opt['lr_inputs']] l_total = 0 for i in range(sequence_len - 2): - real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin) - fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin) - d_fake = net(fake_sext) - d_real = net(real_sext) + real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) + fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin, fp16) + with autocast(enabled=fp16): + d_fake = net(fake_sext) + d_real = net(real_sext) self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_real", torch.mean(d_real)))