forked from mrq/DL-Art-School

Add mechanism to export grad norms

This commit is contained in:
parent d9f8f92840
commit 70fa780edb
@@ -407,6 +407,17 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )
 
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'minicoder': list(self.contextual_embedder.parameters()),
+            'input_blocks': list(self.input_blocks.parameters()),
+            'output_blocks': list(self.output_blocks.parameters()),
+            'middle_transformer': list(self.middle_block.parameters()),
+            'conditioning_encoder': list(self.conditioning_encoder.parameters())
+        }
+        if self.enable_unaligned_inputs:
+            groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
+        return groups
 
     def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
         """
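For context, here is a minimal, self-contained sketch of the contract this hunk introduces: a model exposes get_grad_norm_parameter_groups() returning a dict that maps a label to a list of parameters, and the caller reduces each list to a single gradient-norm value. The TinyModel below is a hypothetical stand-in, not the DiffusionTts module; only the method name and return shape are taken from the diff.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical stand-in for a network exposing grad-norm parameter groups."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.decoder = nn.Linear(16, 4)

    def get_grad_norm_parameter_groups(self):
        # Same contract as the commit: label -> list of parameters.
        return {
            'encoder': list(self.encoder.parameters()),
            'decoder': list(self.decoder.parameters()),
        }

model = TinyModel()
loss = model.decoder(model.encoder(torch.randn(2, 8))).sum()
loss.backward()

# Reduce each group to one scalar: the 2-norm of the per-parameter grad norms.
norms = {
    label: torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)
    for label, params in model.get_grad_norm_parameter_groups().items()
}
print({k: float(v) for k, v in norms.items()})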
@@ -181,6 +181,7 @@ class Trainer:
         # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
         self.total_training_data_encountered += batch_size
+        will_log = self.current_step % opt['logger']['print_freq'] == 0
 
         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
@@ -190,7 +191,7 @@ class Trainer:
             print("Update LR: %f" % (time() - _t))
             _t = time()
         self.model.feed_data(train_data, self.current_step)
-        self.model.optimize_parameters(self.current_step)
+        gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
         if self._profile:
             print("Model feed + step: %f" % (time() - _t))
             _t = time()
@@ -198,13 +199,14 @@ class Trainer:
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
-        if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
+        if will_log and self.rank <= 0:
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000}
             logs.update(self.model.get_current_log(self.current_step))
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
+            logs.update(gradient_norms_dict)
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
             for v in self.model.get_current_learning_rate():
                 message += '{:.3e},'.format(v)
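The Trainer change follows a simple gating pattern: decide up front whether this step will be logged, ask for the (relatively expensive) gradient norms only in that case, and merge whatever comes back into the log dict. A minimal sketch of that pattern follows; run_step, compute_diagnostics, and PRINT_FREQ are hypothetical names, not part of this codebase.

PRINT_FREQ = 100  # assumed stand-in for opt['logger']['print_freq']

def compute_diagnostics():
    """Hypothetical placeholder for the per-group grad-norm pass."""
    return {'grad_norm_example_group': 0.123}

def run_step(step):
    will_log = step % PRINT_FREQ == 0
    # Only pay for diagnostics on steps whose output will actually be printed.
    diagnostics = compute_diagnostics() if will_log else {}
    if will_log:
        logs = {'step': step}
        logs.update(diagnostics)  # mirrors logs.update(gradient_norms_dict)
        print(logs)

for step in range(1, 301):
    run_step(step)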
@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from time import time
 
 import torch
 from torch.nn.parallel import DataParallel
@@ -232,7 +233,9 @@ class ExtensibleTrainer(BaseModel):
                         self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
 
 
-    def optimize_parameters(self, it, optimize=True):
+    def optimize_parameters(self, it, optimize=True, return_grad_norms=False):
+        grad_norms = {}
+
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net.module, "update_for_step"):
@@ -300,6 +303,17 @@ class ExtensibleTrainer(BaseModel):
                             raise OverwrittenStateError(k, list(state.keys()))
                         state[k] = v
 
+            if return_grad_norms and train_step:
+                for name in nets_to_train:
+                    model = self.networks[name]
+                    if hasattr(model.module, 'get_grad_norm_parameter_groups'):
+                        pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
+                    else:
+                        pgroups = {f'{name}_all_parameters': list(model.parameters())}
+                    for name in pgroups.keys():
+                        grad_norms[name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[name]]), 2)
+
             # (Maybe) perform a step.
             if train_step and optimize and self.batch_size_optimizer.should_step(it):
                 self.consume_gradients(state, step, it)
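The reduction in this hunk is the standard "norm of norms" (the same pattern used by torch.nn.utils.clip_grad_norm_): each parameter's gradient is reduced to its L2 norm, the norms are stacked, and the L2 norm of that vector is reported, which equals the L2 norm of all gradients concatenated. A self-contained sketch is below; the guard for p.grad is None is an addition for robustness in the sketch, not something the commit itself does.

import torch

def group_grad_norm(params):
    # L2 norm of all gradients in the group, computed as the norm of per-parameter norms.
    per_param = [torch.norm(p.grad.detach(), 2) for p in params if p.grad is not None]
    if not per_param:
        return torch.tensor(0.0)
    return torch.norm(torch.stack(per_param), 2)

# Quick check that the two views agree: norm-of-norms == norm of the concatenation.
a = torch.randn(3, 4, requires_grad=True)
b = torch.randn(5, requires_grad=True)
(a.sum() * 2 + b.pow(2).sum()).backward()
direct = torch.norm(torch.cat([a.grad.flatten(), b.grad.flatten()]), 2)
assert torch.allclose(group_grad_norm([a, b]), direct)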
@@ -341,6 +355,8 @@ class ExtensibleTrainer(BaseModel):
                     os.makedirs(model_vdbg_dir, exist_ok=True)
                     net.module.visual_dbg(it, model_vdbg_dir)
 
+        return grad_norms
+
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
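Taken together, the pieces give the training loop a per-group gradient-norm dict on logging steps and an empty dict otherwise. The end-to-end sketch below mirrors that contract with a hypothetical StubTrainer; it is not code from this repository, and print_freq stands in for opt['logger']['print_freq'].

import torch
import torch.nn as nn

class StubTrainer:
    """Hypothetical stand-in for ExtensibleTrainer, reduced to the new contract:
    optimize_parameters() returns a (possibly empty) dict of grad-norm tensors."""
    def __init__(self):
        self.net = nn.Linear(4, 2)
        self.opt = torch.optim.SGD(self.net.parameters(), lr=0.1)

    def optimize_parameters(self, it, optimize=True, return_grad_norms=False):
        grad_norms = {}
        self.opt.zero_grad()
        loss = self.net(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        if return_grad_norms:
            params = list(self.net.parameters())
            grad_norms['net_all_parameters'] = torch.norm(
                torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)
        if optimize:
            self.opt.step()
        return grad_norms

print_freq = 50  # stand-in for opt['logger']['print_freq']
trainer = StubTrainer()
for it in range(1, 101):
    will_log = it % print_freq == 0
    norms = trainer.optimize_parameters(it, return_grad_norms=will_log)
    if will_log:
        # The dict values are 0-dim tensors; convert to floats before logging.
        print(it, {k: round(float(v), 4) for k, v in norms.items()})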