From 70fa780edbb33a40b86acbb47c8a339daf0a1f05 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 1 Mar 2022 20:19:52 -0700
Subject: [PATCH] Add mechanism to export grad norms

---
 codes/models/gpt_voice/unet_diffusion_tts7.py | 12 ++++++++++++
 codes/train.py                                |  6 ++++--
 codes/trainer/ExtensibleTrainer.py            | 18 +++++++++++++++++-
 3 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index f99ce990..ed5d696e 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -407,6 +407,18 @@ class DiffusionTts(nn.Module):
             zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
         )
 
+    def get_grad_norm_parameter_groups(self):
+        groups = {
+            'minicoder': list(self.contextual_embedder.parameters()),
+            'input_blocks': list(self.input_blocks.parameters()),
+            'output_blocks': list(self.output_blocks.parameters()),
+            'middle_transformer': list(self.middle_block.parameters()),
+            'conditioning_encoder': list(self.conditioning_encoder.parameters())
+        }
+        if self.enable_unaligned_inputs:
+            groups['unaligned_encoder'] = list(self.unaligned_encoder.parameters())
+        return groups
+
     def forward(self, x, timesteps, tokens=None, conditioning_input=None, lr_input=None, unaligned_input=None, conditioning_free=False):
         """
diff --git a/codes/train.py b/codes/train.py
index 58e840d0..adfb7784 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -181,6 +181,7 @@ class Trainer:
         # because train_data is process-local while the opt variant represents all of the data fed across all GPUs.
         self.current_step += 1
         self.total_training_data_encountered += batch_size
+        will_log = self.current_step % opt['logger']['print_freq'] == 0
 
         #### update learning rate
         self.model.update_learning_rate(self.current_step, warmup_iter=opt['train']['warmup_iter'])
@@ -190,7 +191,7 @@
             print("Update LR: %f" % (time() - _t))
             _t = time()
         self.model.feed_data(train_data, self.current_step)
-        self.model.optimize_parameters(self.current_step)
+        gradient_norms_dict = self.model.optimize_parameters(self.current_step, return_grad_norms=will_log)
         if self._profile:
             print("Model feed + step: %f" % (time() - _t))
             _t = time()
@@ -198,13 +199,14 @@
         #### log
         if self.dataset_debugger is not None:
             self.dataset_debugger.update(train_data)
-        if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
+        if will_log and self.rank <= 0:
             logs = {'step': self.current_step,
                     'samples': self.total_training_data_encountered,
                     'megasamples': self.total_training_data_encountered / 1000000}
             logs.update(self.model.get_current_log(self.current_step))
             if self.dataset_debugger is not None:
                 logs.update(self.dataset_debugger.get_debugging_map())
+            logs.update(gradient_norms_dict)
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
             for v in self.model.get_current_learning_rate():
                 message += '{:.3e},'.format(v)
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 4f0c8b16..3ef20e07 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from time import time
 
 import torch
 from torch.nn.parallel import DataParallel
@@ -232,7 +233,9 @@
                     self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
 
-    def optimize_parameters(self, it, optimize=True):
+    def optimize_parameters(self, it, optimize=True, return_grad_norms=False):
+        grad_norms = {}
+
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net.module, "update_for_step"):
@@ -300,6 +303,17 @@
                         raise OverwrittenStateError(k, list(state.keys()))
                     state[k] = v
 
+            if return_grad_norms and train_step:
+                for name in nets_to_train:
+                    model = self.networks[name]
+                    if hasattr(model.module, 'get_grad_norm_parameter_groups'):
+                        pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
+                    else:
+                        pgroups = {f'{name}_all_parameters': list(model.parameters())}
+                    for group_name in pgroups.keys():
+                        grad_norms[group_name] = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in pgroups[group_name]]), 2)
+
+
             # (Maybe) perform a step.
             if train_step and optimize and self.batch_size_optimizer.should_step(it):
                 self.consume_gradients(state, step, it)
@@ -341,6 +355,8 @@
                     os.makedirs(model_vdbg_dir, exist_ok=True)
                     net.module.visual_dbg(it, model_vdbg_dir)
 
+        return grad_norms
+
     def consume_gradients(self, state, step, it):
         [e.before_optimize(state) for e in self.experiments]
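
For context on the reduction used in optimize_parameters() above: each reported value is the L2 norm of the per-parameter gradient norms within a group, i.e. the usual "total gradient norm" of that group. Below is a minimal standalone sketch of the same computation, assuming a toy two-layer module and made-up group names; nothing in it is part of the patch itself.

import torch
import torch.nn as nn

# Toy stand-in for one of the trained networks (illustrative assumption only).
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
net(torch.randn(3, 4)).sum().backward()

# Mirrors the shape returned by get_grad_norm_parameter_groups(): name -> parameter list.
groups = {
    'first_layer': list(net[0].parameters()),
    'second_layer': list(net[1].parameters()),
}

# Same reduction as in optimize_parameters(): L2 norm of the stacked
# per-parameter gradient norms, giving one scalar per group.
grad_norms = {name: torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)
              for name, params in groups.items()}
print(grad_norms)

Reporting these per group, rather than a single global norm, makes it easier to see which sub-module's gradients are vanishing or exploding when the logged values are inspected.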