From 7213ad2b89fe9bfb836b7483bf50572066b796f5 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 17 May 2022 17:59:40 -0600
Subject: [PATCH] Do grad reduction

---
 codes/models/audio/mel2vec.py      | 15 ++++++++++-----
 codes/trainer/ExtensibleTrainer.py |  5 +++++
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index ea0fff6b..285cf299 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -545,8 +545,8 @@ class ContrastiveTrainingWrapper(nn.Module):
         self.max_gumbel_temperature = max_gumbel_temperature
         self.min_gumbel_temperature = min_gumbel_temperature
         self.gumbel_temperature_decay = gumbel_temperature_decay
-
         self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
+        self.num_losses_record = []
 
         # make sure that project_hid & project_q are initialized like normal linear layers
         self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
@@ -645,15 +645,20 @@ class ContrastiveTrainingWrapper(nn.Module):
         num_losses = mask_time_indices.sum()
         if distributed.is_initialized():
-            num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
+            distributed.all_reduce(num_losses)
+            num_losses = num_losses / distributed.get_world_size()
 
         contrastive_loss = contrastive_loss / num_losses
         diversity_loss = diversity_loss / num_losses
+        self.num_losses_record = num_losses.detach()
 
         return contrastive_loss, diversity_loss
 
-    def before_step(self, it):
-        # Unscale the grads by the total number of losses encountered.
-        pass
+    def after_backward(self, it):
+        if self.num_losses_record > 0:
+            # Unscale the grads by the total number of losses encountered.
+            for p in self.parameters():
+                if p.grad is not None:
+                    p.grad.data.div_(self.num_losses_record)
 
 
 @register_model
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 89d5e753..9ecbb836 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -294,6 +294,11 @@ class ExtensibleTrainer(BaseModel):
                 self.batch_size_optimizer.focus(net)
             for m in range(self.batch_factor):
                 ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
+                # Call into post-backward hooks.
+                for name, net in self.networks.items():
+                    if hasattr(net.module, "after_backward"):
+                        net.module.after_backward(it)
+
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
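
For context, below is a minimal, self-contained sketch of the pattern the patch implements: the forward pass sums a loss over a variable number of masked positions, records that count (all-reduced and averaged across DDP ranks), and a post-backward hook divides the accumulated gradients by the recorded count. This is illustrative only; the module, parameter, and variable names here are hypothetical and not taken from the repository.

import torch
import torch.nn as nn
import torch.distributed as distributed


class GradReductionSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        # Count of loss terms from the most recent forward pass; zero until forward runs.
        self.num_losses_record = torch.tensor(0.0)

    def forward(self, x, mask):
        # Sum a per-position loss over only the masked positions.
        loss = (self.proj(x)[mask] ** 2).sum()

        # Count how many positions contributed. Under DDP, all-reduce the count and
        # average it so every rank later divides its grads by the same value.
        num_losses = mask.sum().float()
        if distributed.is_initialized():
            distributed.all_reduce(num_losses)  # in-place sum across ranks
            num_losses = num_losses / distributed.get_world_size()
        self.num_losses_record = num_losses.detach()
        return loss

    def after_backward(self):
        # Unscale the accumulated grads by the number of loss terms rather than scaling
        # the loss before backward; a trainer would call this hook once backward finishes.
        if self.num_losses_record > 0:
            for p in self.parameters():
                if p.grad is not None:
                    p.grad.data.div_(self.num_losses_record)


if __name__ == "__main__":
    model = GradReductionSketch()
    x = torch.randn(4, 16)
    mask = torch.tensor([True, True, False, False])
    model(x, mask).backward()
    model.after_backward()  # grads are now normalized by the loss count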