Enable amp to be disabled
This commit is contained in:
parent
c04f244802
commit
3027e6e27d
|
@ -102,7 +102,7 @@ class FullImageDataset(data.Dataset):
|
|||
left = self.pick_along_range(w, square_size, .3)
|
||||
top = self.pick_along_range(w, square_size, .3)
|
||||
|
||||
mask = np.zeros((h, w, 1), dtype=np.float)
|
||||
mask = np.zeros((h, w, 1), dtype=image.dtype)
|
||||
mask[top:top+square_size, left:left+square_size] = 1
|
||||
patch = image[top:top+square_size, left:left+square_size, :]
|
||||
center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
|
||||
|
|
|
@ -82,8 +82,14 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
# Initialize amp.
|
||||
total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
|
||||
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||
if 'amp_opt_level' in opt.keys():
|
||||
self.env['amp'] = True
|
||||
amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
|
||||
self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
|
||||
else:
|
||||
amp_nets = total_nets + [self.netF] + self.steps
|
||||
amp_opts = self.optimizers
|
||||
self.env['amp'] = False
|
||||
|
||||
# Unwrap steps & netF
|
||||
self.netF = amp_nets[len(total_nets)]
|
||||
|
|
|
@ -115,8 +115,11 @@ class ConfigurableStep(Module):
|
|||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
if self.env['amp']:
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
|
||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
# we must release the gradients.
|
||||
|
|
Loading…
Reference in New Issue
Block a user