diff --git a/codes/data/full_image_dataset.py b/codes/data/full_image_dataset.py
index 9e890dc8..e4f297ab 100644
--- a/codes/data/full_image_dataset.py
+++ b/codes/data/full_image_dataset.py
@@ -102,7 +102,7 @@ class FullImageDataset(data.Dataset):
         left = self.pick_along_range(w, square_size, .3)
         top = self.pick_along_range(w, square_size, .3)
 
-        mask = np.zeros((h, w, 1), dtype=np.float)
+        mask = np.zeros((h, w, 1), dtype=image.dtype)
         mask[top:top+square_size, left:left+square_size] = 1
         patch = image[top:top+square_size, left:left+square_size, :]
         center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 985d017f..20130840 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -82,8 +82,14 @@ class ExtensibleTrainer(BaseModel):
 
         # Initialize amp.
         total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
-        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
-                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        if 'amp_opt_level' in opt.keys():
+            self.env['amp'] = True
+            amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
+                                                self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        else:
+            amp_nets = total_nets + [self.netF] + self.steps
+            amp_opts = self.optimizers
+            self.env['amp'] = False
 
         # Unwrap steps & netF
         self.netF = amp_nets[len(total_nets)]
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index 9e8080be..e7de2e8d 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -115,8 +115,11 @@ class ConfigurableStep(Module):
         total_loss = total_loss / self.env['mega_batch_factor']
 
         # Get dem grads!
-        with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-            scaled_loss.backward()
+        if self.env['amp']:
+            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            total_loss.backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
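
For reference, below is a minimal standalone sketch of the pattern this patch introduces: apex AMP is initialized and used only when the config supplies an 'amp_opt_level' key, and training otherwise falls back to a plain FP32 backward pass. The names opt, net, optimizer and the toy model are illustrative placeholders, not objects from ExtensibleTrainer or ConfigurableStep.

# Sketch only: gate apex AMP behind a config key, mirroring the two changes above.
import torch
import torch.nn as nn

try:
    from apex import amp  # optional dependency; only needed when AMP is requested
except ImportError:
    amp = None

opt = {'amp_opt_level': 'O1'}  # drop this key to train without apex AMP
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = 'amp_opt_level' in opt and amp is not None and device == 'cuda'

net = nn.Linear(8, 1).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

if use_amp:
    # Mirrors the ExtensibleTrainer change: wrap nets/optimizers only when requested.
    net, optimizer = amp.initialize(net, optimizer, opt_level=opt['amp_opt_level'])

loss = net(torch.randn(4, 8, device=device)).mean()
optimizer.zero_grad()

if use_amp:
    # Mirrors the steps.py change: scale the loss through apex before backward.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()
optimizer.step()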