auto grad "lr" scaling
This commit is contained in:
parent
e5d97dfd56
commit
78bba690de
|
@ -1,6 +1,7 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
from math import sqrt
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
|
@ -62,6 +63,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.checkpointing_cache = opt['checkpointing_enabled']
|
||||
self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
|
||||
self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
|
||||
self.auto_scale_grads = opt_get(opt, ['automatically_scale_grads_for_fanin'], False)
|
||||
self.auto_scale_basis = opt_get(opt, ['automatically_scale_base_layer_size'], 1024)
|
||||
|
||||
self.netsG = {}
|
||||
self.netsD = {}
|
||||
|
@ -315,11 +318,6 @@ class ExtensibleTrainer(BaseModel):
|
|||
|
||||
# (Maybe) perform a step.
|
||||
if train_step and optimize and self.batch_size_optimizer.should_step(it):
|
||||
# Call into pre-step hooks.
|
||||
for name, net in self.networks.items():
|
||||
if hasattr(net.module, "before_step"):
|
||||
net.module.before_step(it)
|
||||
|
||||
# Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
|
||||
# This is needed to accurately log the grad norms.
|
||||
for opt in step.optimizers:
|
||||
|
@ -327,6 +325,27 @@ class ExtensibleTrainer(BaseModel):
|
|||
if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
|
||||
step.scaler.unscale_(opt)
|
||||
|
||||
# Call into pre-step hooks.
|
||||
for name, net in self.networks.items():
|
||||
if hasattr(net.module, "before_step"):
|
||||
net.module.before_step(it)
|
||||
|
||||
if self.auto_scale_grads:
|
||||
asb = sqrt(self.auto_scale_basis)
|
||||
for net in self.networks.values():
|
||||
for mod in net.modules():
|
||||
fan_in = -1
|
||||
if isinstance(mod, nn.Linear):
|
||||
fan_in = mod.weight.data.shape[1]
|
||||
elif isinstance(mod, nn.Conv1d):
|
||||
fan_in = mod.weight.data.shape[0]
|
||||
elif isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Conv3d):
|
||||
assert "Not yet implemented!"
|
||||
if fan_in != -1:
|
||||
p = mod.weight
|
||||
if hasattr(p, 'grad') and p.grad is not None:
|
||||
p.grad = p.grad * asb / sqrt(fan_in)
|
||||
|
||||
if return_grad_norms and train_step:
|
||||
for name in nets_to_train:
|
||||
model = self.networks[name]
|
||||
|
|
Loading…
Reference in New Issue
Block a user