auto grad "lr" scaling

James Betker 2022-07-08 00:38:25 -06:00
parent e5d97dfd56
commit 78bba690de


@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from math import sqrt
 from time import time
 import torch
@@ -62,6 +63,8 @@ class ExtensibleTrainer(BaseModel):
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
         self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
+        self.auto_scale_grads = opt_get(opt, ['automatically_scale_grads_for_fanin'], False)
+        self.auto_scale_basis = opt_get(opt, ['automatically_scale_base_layer_size'], 1024)
         self.netsG = {}
         self.netsD = {}
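
The two new attributes are read from the training opt, so the feature is opt-in. A minimal sketch of the corresponding opt entries (hypothetical excerpt; the keys match the opt_get() lookups above, and 1024 mirrors the default):

    opt = {
        'automatically_scale_grads_for_fanin': True,   # enable fan-in grad scaling (defaults to False)
        'automatically_scale_base_layer_size': 1024,   # reference fan-in treated as the 1x scale (the default)
    }
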
@@ -315,11 +318,6 @@ class ExtensibleTrainer(BaseModel):
         # (Maybe) perform a step.
         if train_step and optimize and self.batch_size_optimizer.should_step(it):
-            # Call into pre-step hooks.
-            for name, net in self.networks.items():
-                if hasattr(net.module, "before_step"):
-                    net.module.before_step(it)
             # Unscale gradients within the step. (This is admittedly pretty messy, but the API contract between step & ET is pretty much broken at this point.)
             # This is needed to accurately log the grad norms.
             for opt in step.optimizers:
@@ -327,6 +325,27 @@ class ExtensibleTrainer(BaseModel):
                 if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                     step.scaler.unscale_(opt)
+            # Call into pre-step hooks.
+            for name, net in self.networks.items():
+                if hasattr(net.module, "before_step"):
+                    net.module.before_step(it)
+            if self.auto_scale_grads:
+                asb = sqrt(self.auto_scale_basis)
+                for net in self.networks.values():
+                    for mod in net.modules():
+                        fan_in = -1
+                        if isinstance(mod, nn.Linear):
+                            fan_in = mod.weight.data.shape[1]  # weight is (out_features, in_features)
+                        elif isinstance(mod, nn.Conv1d):
+                            fan_in = mod.weight.data.shape[1] * mod.weight.data.shape[2]  # weight is (out, in, kernel): in_channels * kernel_size
+                        elif isinstance(mod, (nn.Conv2d, nn.Conv3d)):
+                            raise NotImplementedError("Fan-in grad scaling is not yet implemented for Conv2d/Conv3d.")
+                        if fan_in != -1:
+                            p = mod.weight
+                            if p.grad is not None:
+                                p.grad = p.grad * asb / sqrt(fan_in)
         if return_grad_norms and train_step:
             for name in nets_to_train:
                 model = self.networks[name]
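
Net effect of the new block: each eligible weight gradient is multiplied by sqrt(auto_scale_basis / fan_in). With the default basis of 1024, a Linear layer with 256 inputs has its gradients doubled (sqrt(1024)/sqrt(256) = 32/16 = 2), while a 4096-wide layer would have them halved. A minimal standalone sketch of the same rule outside the trainer (the function name and toy model are illustrative, not part of this commit):

    from math import sqrt

    import torch
    import torch.nn as nn

    def scale_grads_for_fanin(model: nn.Module, basis: int = 1024):
        # Multiply each Linear weight grad by sqrt(basis / fan_in),
        # mirroring the per-module loop added above (Linear only, for brevity).
        asb = sqrt(basis)
        for mod in model.modules():
            if isinstance(mod, nn.Linear) and mod.weight.grad is not None:
                fan_in = mod.weight.data.shape[1]  # in_features
                mod.weight.grad = mod.weight.grad * asb / sqrt(fan_in)

    # Toy check: a 256-input layer's grads should be scaled by exactly 2.
    layer = nn.Linear(256, 10)
    layer(torch.randn(4, 256)).sum().backward()
    before = layer.weight.grad.clone()
    scale_grads_for_fanin(layer, basis=1024)
    assert torch.allclose(layer.weight.grad, before * 2.0)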