From 78bba690ded3b9498d036cf9d35c70666d51fdce Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 8 Jul 2022 00:38:25 -0600
Subject: [PATCH] auto grad "lr" scaling

---
 codes/trainer/ExtensibleTrainer.py | 29 ++++++++++++++++++++++++-----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 40564d6b..79925764 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from math import sqrt
 from time import time
 
 import torch
@@ -62,6 +63,8 @@ class ExtensibleTrainer(BaseModel):
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
         self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
+        self.auto_scale_grads = opt_get(opt, ['automatically_scale_grads_for_fanin'], False)
+        self.auto_scale_basis = opt_get(opt, ['automatically_scale_base_layer_size'], 1024)
 
         self.netsG = {}
         self.netsD = {}
@@ -315,11 +318,6 @@ class ExtensibleTrainer(BaseModel):
 
             # (Maybe) perform a step.
             if train_step and optimize and self.batch_size_optimizer.should_step(it):
-                # Call into pre-step hooks.
-                for name, net in self.networks.items():
-                    if hasattr(net.module, "before_step"):
-                        net.module.before_step(it)
-
                 # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
                 # This is needed to accurately log the grad norms.
                 for opt in step.optimizers:
@@ -327,6 +325,27 @@ class ExtensibleTrainer(BaseModel):
                     if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                         step.scaler.unscale_(opt)
 
+                # Call into pre-step hooks.
+                for name, net in self.networks.items():
+                    if hasattr(net.module, "before_step"):
+                        net.module.before_step(it)
+
+                if self.auto_scale_grads:
+                    asb = sqrt(self.auto_scale_basis)
+                    for net in self.networks.values():
+                        for mod in net.modules():
+                            fan_in = -1
+                            if isinstance(mod, nn.Linear):
+                                fan_in = mod.weight.data.shape[1]
+                            elif isinstance(mod, nn.Conv1d):
+                                fan_in = mod.weight.data.shape[0]
+                            elif isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Conv3d):
+                                assert "Not yet implemented!"
+                            if fan_in != -1:
+                                p = mod.weight
+                                if hasattr(p, 'grad') and p.grad is not None:
+                                    p.grad = p.grad * asb / sqrt(fan_in)
+
             if return_grad_norms and train_step:
                 for name in nets_to_train:
                     model = self.networks[name]
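
The added block multiplies each eligible weight gradient by sqrt(auto_scale_basis) / sqrt(fan_in), i.e. by sqrt(auto_scale_basis / fan_in), so layers wider than the base size take proportionally smaller effective steps once the GradScaler has unscaled the gradients. Below is a minimal standalone sketch of that idea for a plain nn.Module outside ExtensibleTrainer; the helper name scale_grads_for_fanin is hypothetical and not part of this repo, and it derives Conv1d fan-in from in_channels * kernel_size rather than reading weight.shape[0] as the patch's Conv1d branch does.

from math import sqrt

import torch
import torch.nn as nn


def scale_grads_for_fanin(model: nn.Module, basis: int = 1024) -> None:
    # Scale each Linear/Conv1d weight gradient by sqrt(basis / fan_in).
    # `basis` plays the role of the patch's `automatically_scale_base_layer_size` option.
    asb = sqrt(basis)
    for mod in model.modules():
        fan_in = -1
        if isinstance(mod, nn.Linear):
            fan_in = mod.weight.shape[1]                        # in_features
        elif isinstance(mod, nn.Conv1d):
            # Assumption: fan-in as in_channels * kernel_size (differs from the patch's shape[0]).
            fan_in = mod.weight.shape[1] * mod.weight.shape[2]
        if fan_in > 0 and mod.weight.grad is not None:
            mod.weight.grad.mul_(asb / sqrt(fan_in))


if __name__ == '__main__':
    net = nn.Sequential(nn.Linear(2048, 512), nn.ReLU(), nn.Linear(512, 16))
    net(torch.randn(4, 2048)).sum().backward()
    scale_grads_for_fanin(net)  # call after any grad unscaling, just before optimizer.step()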