forked from mrq/DL-Art-School
Enable amp to be disabled
This commit is contained in:
parent c04f244802
commit 3027e6e27d
@@ -102,7 +102,7 @@ class FullImageDataset(data.Dataset):
         left = self.pick_along_range(w, square_size, .3)
         top = self.pick_along_range(w, square_size, .3)
 
-        mask = np.zeros((h, w, 1), dtype=np.float)
+        mask = np.zeros((h, w, 1), dtype=image.dtype)
         mask[top:top+square_size, left:left+square_size] = 1
         patch = image[top:top+square_size, left:left+square_size, :]
         center = torch.tensor([top + square_size // 2, left + square_size // 2], dtype=torch.long)
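
The dtype change matters because np.float is just an alias for float64 (and is deprecated in newer NumPy), so a float64 mask silently promotes anything it is later combined with. A minimal sketch of that promotion, assuming the mask is eventually stacked with a float32 image (how the dataset actually consumes the mask is not shown in this diff, and the shapes below are made up):

import numpy as np

image = np.zeros((64, 64, 3), dtype=np.float32)
old_mask = np.zeros((64, 64, 1), dtype=np.float64)    # old behaviour: dtype=np.float, i.e. float64
new_mask = np.zeros((64, 64, 1), dtype=image.dtype)   # new behaviour: follow the image's dtype

print(np.concatenate([image, old_mask], axis=2).dtype)  # float64 - silent upcast
print(np.concatenate([image, new_mask], axis=2).dtype)  # float32
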
@@ -82,8 +82,14 @@ class ExtensibleTrainer(BaseModel):
 
         # Initialize amp.
         total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()]
-        amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
-                                            self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        if 'amp_opt_level' in opt.keys():
+            self.env['amp'] = True
+            amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps,
+                                                self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps']))
+        else:
+            amp_nets = total_nets + [self.netF] + self.steps
+            amp_opts = self.optimizers
+            self.env['amp'] = False
 
         # Unwrap steps & netF
         self.netF = amp_nets[len(total_nets)]
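
With this change, apex amp becomes opt-in: it is only initialized when the training options contain an amp_opt_level key, and self.env['amp'] records that choice for the later training steps. A rough sketch of the gating logic in isolation (the opt dicts below are invented examples, not configs shipped with the repo):

def amp_requested(opt):
    # Mirrors the new check in ExtensibleTrainer: amp runs only if the key is present.
    return 'amp_opt_level' in opt.keys()

print(amp_requested({'amp_opt_level': 'O1', 'steps': {'generator': {}}}))  # True -> amp.initialize() path
print(amp_requested({'steps': {'generator': {}}}))                         # False -> plain fp32 path
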
@@ -115,8 +115,11 @@ class ConfigurableStep(Module):
         total_loss = total_loss / self.env['mega_batch_factor']
 
         # Get dem grads!
-        with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
-            scaled_loss.backward()
+        if self.env['amp']:
+            with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            total_loss.backward()
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
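
The backward pass now follows the same flag: when amp is active the loss goes through amp.scale_loss so gradients are scaled and unscaled correctly, otherwise a plain .backward() is used. A standalone illustration of the non-amp branch, with a toy model standing in for the configured step (the model, learning rate, and mega_batch_factor value are placeholders; the amp branch is only commented since it requires apex to be installed):

import torch

model = torch.nn.Linear(4, 1)            # stand-in for a configured network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

total_loss = model(torch.randn(8, 4)).mean()
mega_batch_factor = 2                     # as in self.env['mega_batch_factor']
total_loss = total_loss / mega_batch_factor

env_amp = False                           # what self.env['amp'] holds when amp_opt_level is absent
if env_amp:
    pass  # would wrap the loss in amp.scale_loss(...) and call scaled_loss.backward()
else:
    total_loss.backward()                 # plain fp32 backward, no loss scaling needed
optimizer.step()
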