Enable amp to be disabled

James Betker 2020-09-09 10:45:59 -06:00
parent c04f244802
commit 3027e6e27d
3 changed files with 14 additions and 5 deletions
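The change makes Apex AMP opt-in: it is initialized only when the training options contain an `amp_opt_level` key, and the chosen mode is recorded in `self.env['amp']` so later code can branch on it. A minimal sketch of the two option dicts (the surrounding keys are placeholders, not taken from this repo):

# Hypothetical option dicts illustrating the new toggle; only the
# 'amp_opt_level' key is meaningful to ExtensibleTrainer here.
opt_with_amp = {
    'steps': {'generator': {}},   # one AMP loss id is allocated per step
    'amp_opt_level': 'O1',        # presence of this key enables AMP
}
opt_without_amp = {
    'steps': {'generator': {}},
    # no 'amp_opt_level' key: amp.initialize is skipped entirely and
    # self.env['amp'] is set to False
}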

@@ -102,7 +102,7 @@ class FullImageDataset(data.Dataset):
         left = self.pick_along_range(w, square_size, .3)
         top = self.pick_along_range(w, square_size, .3)
-        mask = np.zeros((h, w, 1), dtype=np.float)
+        mask = np.zeros((h, w, 1), dtype=image.dtype)
         mask[top:top+square_size, left:left+square_size] = 1
         patch = image[top:top+square_size, left:left+square_size, :]
         center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
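Aside: the dataset tweak above replaces the deprecated `np.float` (which always meant float64) with `image.dtype`, so the mask matches the image's precision, which matters once half-precision tensors are in play. A standalone illustration, assuming a float32 image:

import numpy as np

image = np.random.rand(64, 64, 3).astype(np.float32)  # stand-in image

mask_old = np.zeros((64, 64, 1), dtype=float)         # what np.float meant: float64
mask_new = np.zeros((64, 64, 1), dtype=image.dtype)   # matches the image

assert mask_old.dtype != image.dtype
assert mask_new.dtype == image.dtype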

@@ -82,8 +82,14 @@ class ExtensibleTrainer(BaseModel):
         # Initialize amp.
         total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
-        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
-                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        if 'amp_opt_level' in opt.keys():
+            self.env['amp'] = True
+            amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
+                                                self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        else:
+            amp_nets = total_nets + [self.netF] + self.steps
+            amp_opts = self.optimizers
+            self.env['amp'] = False
         # Unwrap steps & netF
         self.netF = amp_nets[len(total_nets)]
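The shape of this hunk is the standard opt-in pattern for apex: `amp.initialize` returns rewritten copies of the networks and optimizers, so when AMP is off the original objects must flow through under the same names for the unwrapping below to work in both modes. A minimal sketch of that gating, assuming apex is installed (all names are placeholders):

def maybe_init_amp(nets, optimizers, opt):
    """Wrap nets/optimizers with apex AMP only if the config requests it."""
    if 'amp_opt_level' in opt:
        from apex import amp  # imported lazily; only needed when enabled
        nets, optimizers = amp.initialize(nets, optimizers,
                                          opt_level=opt['amp_opt_level'],
                                          num_losses=len(opt['steps']))
        enabled = True
    else:
        enabled = False  # pass the originals through untouched
    return nets, optimizers, enabled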

@@ -115,8 +115,11 @@ class ConfigurableStep(Module):
         total_loss = total_loss / self.env['mega_batch_factor']
         # Get dem grads!
-        with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-            scaled_loss.backward()
+        if self.env['amp']:
+            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            total_loss.backward()
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
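With the flag stored in the environment, the backward pass mirrors the same branch: under AMP the loss must go through apex's `amp.scale_loss` context manager, which scales the loss before `backward()` and unscales the gradients on exit, keyed by the loss id; otherwise a plain `backward()` suffices. A minimal sketch (names are placeholders):

def backward_pass(total_loss, optimizers, amp_enabled, amp_loss_id=0):
    if amp_enabled:
        from apex import amp  # only required on the AMP path
        # Scales the loss for fp16-safe gradients; unscaling happens
        # when the context manager exits.
        with amp.scale_loss(total_loss, optimizers, loss_id=amp_loss_id) as scaled_loss:
            scaled_loss.backward()
    else:
        total_loss.backward()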