diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index ea0fff6b..285cf299 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -545,8 +545,8 @@ class ContrastiveTrainingWrapper(nn.Module):
         self.max_gumbel_temperature = max_gumbel_temperature
         self.min_gumbel_temperature = min_gumbel_temperature
         self.gumbel_temperature_decay = gumbel_temperature_decay
-
         self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
+        self.num_losses_record = 0  # numeric, so after_backward() is a safe no-op until a count is recorded
 
         # make sure that project_hid & project_q are initialized like normal linear layers
         self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
@@ -645,15 +645,24 @@ class ContrastiveTrainingWrapper(nn.Module):
 
         num_losses = mask_time_indices.sum()
         if distributed.is_initialized():
-            num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
+            distributed.all_reduce(num_losses)
+            num_losses = num_losses / distributed.get_world_size()
         contrastive_loss = contrastive_loss / num_losses
         diversity_loss = diversity_loss / num_losses
+        self.num_losses_record = num_losses.detach()
 
         return contrastive_loss, diversity_loss
 
-    def before_step(self, it):
-        # Unscale the grads by the total number of losses encountered.
-        pass
+    def after_backward(self, it):
+        """Divide each parameter gradient by the most recently recorded loss count.
+
+        `it` is passed in by the trainer's post-backward hook and is unused here.
+        """
+        if self.num_losses_record > 0:
+            # Unscale the grads by the total number of losses encountered.
+            for p in self.parameters():
+                if p.grad is not None:
+                    p.grad.data.div_(self.num_losses_record)
 
 
 @register_model
diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 89d5e753..9ecbb836 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -294,6 +294,11 @@ class ExtensibleTrainer(BaseModel):
             self.batch_size_optimizer.focus(net)
             for m in range(self.batch_factor):
                 ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
+                # Call into post-backward hooks (distinct loop name so the enclosing `net` is not rebound).
+                for hooked_net in self.networks.values():
+                    if hasattr(hooked_net.module, "after_backward"):
+                        hooked_net.module.after_backward(it)
+
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]