forked from mrq/DL-Art-School

commit 78bba690de (parent e5d97dfd56): auto grad "lr" scaling
@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from math import sqrt
 from time import time
 
 import torch
@@ -62,6 +63,8 @@ class ExtensibleTrainer(BaseModel):
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
         self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
+        self.auto_scale_grads = opt_get(opt, ['automatically_scale_grads_for_fanin'], False)
+        self.auto_scale_basis = opt_get(opt, ['automatically_scale_base_layer_size'], 1024)
 
         self.netsG = {}
         self.netsD = {}
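The two new options are looked up with opt_get and default to False and 1024 when absent. A minimal sketch of an options dict carrying these keys (only the key names and defaults come from the diff; the dict layout and the plain dict.get lookup are assumptions, since the real trainer reads a full options file):

# Hypothetical options dict; only the two key names and their defaults are from the diff.
opt = {
    'automatically_scale_grads_for_fanin': True,    # turn the fan-in grad scaling on
    'automatically_scale_base_layer_size': 1024,    # fan-in that maps to a scale factor of 1.0
}

# Equivalent lookups with defaults, mirroring the opt_get calls above (dict.get used for illustration):
auto_scale_grads = opt.get('automatically_scale_grads_for_fanin', False)
auto_scale_basis = opt.get('automatically_scale_base_layer_size', 1024)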
@@ -315,11 +318,6 @@ class ExtensibleTrainer(BaseModel):
 
         # (Maybe) perform a step.
         if train_step and optimize and self.batch_size_optimizer.should_step(it):
-            # Call into pre-step hooks.
-            for name, net in self.networks.items():
-                if hasattr(net.module, "before_step"):
-                    net.module.before_step(it)
-
             # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
             # This is needed to accurately log the grad norms.
             for opt in step.optimizers:
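The comment above refers to unscaling gradients so their norms can be logged accurately; the next hunk guards the call with GradScaler's private _per_optimizer_states so unscale_ is not invoked twice for the same optimizer in one step. For reference, a self-contained sketch of the standard torch.cuda.amp pattern this relies on (the toy model, optimizer, and loss are assumptions, and a CUDA device is assumed):

import torch
from torch import nn

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 16, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)   # grads now hold their true magnitudes, so norms are meaningful
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
scaler.step(optimizer)       # step() sees this optimizer was already unscaled and skips re-unscaling
scaler.update()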
@@ -327,6 +325,27 @@ class ExtensibleTrainer(BaseModel):
                 if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(opt)]["stage"] is not OptState.UNSCALED:
                     step.scaler.unscale_(opt)
 
+            # Call into pre-step hooks.
+            for name, net in self.networks.items():
+                if hasattr(net.module, "before_step"):
+                    net.module.before_step(it)
+
+            if self.auto_scale_grads:
+                asb = sqrt(self.auto_scale_basis)
+                for net in self.networks.values():
+                    for mod in net.modules():
+                        fan_in = -1
+                        if isinstance(mod, nn.Linear):
+                            fan_in = mod.weight.data.shape[1]
+                        elif isinstance(mod, nn.Conv1d):
+                            fan_in = mod.weight.data.shape[0]
+                        elif isinstance(mod, nn.Conv2d) or isinstance(mod, nn.Conv3d):
+                            assert "Not yet implemented!"
+                        if fan_in != -1:
+                            p = mod.weight
+                            if hasattr(p, 'grad') and p.grad is not None:
+                                p.grad = p.grad * asb / sqrt(fan_in)
+
             if return_grad_norms and train_step:
                 for name in nets_to_train:
                     model = self.networks[name]
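The added block rescales each weight gradient by sqrt(auto_scale_basis) / sqrt(fan_in), i.e. by sqrt(basis / fan_in): layers narrower than the basis receive larger effective learning rates and wider layers receive smaller ones, which is the auto grad "lr" scaling named in the commit message. A standalone sketch of the same rule under assumed names (scale_grads_for_fanin and the toy model are illustrations, not part of the repository):

import torch
from torch import nn
from math import sqrt

def scale_grads_for_fanin(model: nn.Module, basis: int = 1024) -> None:
    # Multiply weight grads by sqrt(basis / fan_in), following the rule added in the diff.
    asb = sqrt(basis)
    for mod in model.modules():
        fan_in = -1
        if isinstance(mod, nn.Linear):
            fan_in = mod.weight.shape[1]   # Linear weight is (out_features, in_features)
        elif isinstance(mod, nn.Conv1d):
            fan_in = mod.weight.shape[0]   # index used in the diff; note shape[0] is out_channels for Conv1d
        if fan_in != -1 and mod.weight.grad is not None:
            mod.weight.grad = mod.weight.grad * asb / sqrt(fan_in)

# Toy usage: a Linear with fan_in 256 gets its gradient scaled by sqrt(1024 / 256) = 2.
model = nn.Sequential(nn.Linear(256, 64))
model(torch.randn(4, 256)).sum().backward()
scale_grads_for_fanin(model, basis=1024)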