From 3027e6e27d11b82f151604984e2d3e77a0977bfd Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 9 Sep 2020 10:45:59 -0600
Subject: [PATCH] Enable amp to be disabled

---
 codes/data/full_image_dataset.py  |  2 +-
 codes/models/ExtensibleTrainer.py | 10 ++++++++--
 codes/models/steps/steps.py       |  7 +++++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py
index 9e890dc8..e4f297ab 100644
--- a/codes/data/full_image_dataset.py
+++ b/codes/data/full_image_dataset.py
@@ -102,7 +102,7 @@ class FullImageDataset(data.Dataset):
         left = self.pick_along_range(w, square_size, .3)
         top = self.pick_along_range(w, square_size, .3)
 
-        mask = np.zeros((h, w, 1), dtype=np.float)
+        mask = np.zeros((h, w, 1), dtype=image.dtype)
         mask[top:top+square_size, left:left+square_size] = 1
         patch = image[top:top+square_size, left:left+square_size, :]
         center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 985d017f..20130840 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -82,8 +82,14 @@ class ExtensibleTrainer(BaseModel):
 
         # Initialize amp.
         total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
-        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
-                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        if 'amp_opt_level' in opt.keys():
+            self.env['amp'] = True
+            amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
+                                                self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        else:
+            amp_nets = total_nets + [self.netF] + self.steps
+            amp_opts = self.optimizers
+            self.env['amp'] = False
 
         # Unwrap steps & netF
         self.netF = amp_nets[len(total_nets)]
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 9e8080be..e7de2e8d 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -115,8 +115,11 @@ class ConfigurableStep(Module):
         total_loss = total_loss / self.env['mega_batch_factor']
 
         # Get dem grads!
-        with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-            scaled_loss.backward()
+        if self.env['amp']:
+            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            total_loss.backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
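
Note: the guarded-AMP pattern this patch introduces can be demonstrated outside the trainer. Below is a minimal standalone sketch: apex AMP is initialized and used for loss scaling only when the options dict carries an 'amp_opt_level' key; otherwise the backward pass runs in plain FP32. The toy nn.Linear model, Adam optimizer, and opt dict are illustrative assumptions, not the repository's actual objects; only the amp.initialize and amp.scale_loss calls are taken from the patch itself.

    # Sketch of optional apex AMP, mirroring the patch's 'amp_opt_level' check.
    # The model, optimizer, and opt dict are placeholders for illustration.
    import torch
    import torch.nn as nn

    opt = {'amp_opt_level': 'O1'}  # drop this key to fall back to plain FP32

    model = nn.Linear(16, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    use_amp = 'amp_opt_level' in opt
    if use_amp:
        from apex import amp  # requires NVIDIA apex to be installed
        model, optimizer = amp.initialize(model, optimizer, opt_level=opt['amp_opt_level'])

    x = torch.randn(8, 16).cuda()
    loss = model(x).mean()

    if use_amp:
        # apex scales the loss so FP16 gradients do not underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()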