Do grad reduction

James Betker 2022-05-17 17:59:40 -06:00
parent 7c82e18c6c
commit 7213ad2b89
2 changed files with 15 additions and 5 deletions


@@ -545,8 +545,8 @@ class ContrastiveTrainingWrapper(nn.Module):
        self.max_gumbel_temperature = max_gumbel_temperature
        self.min_gumbel_temperature = min_gumbel_temperature
        self.gumbel_temperature_decay = gumbel_temperature_decay
        self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
        self.num_losses_record = []

        # make sure that project_hid & project_q are initialized like normal linear layers
        self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
@@ -645,15 +645,20 @@ class ContrastiveTrainingWrapper(nn.Module):
        num_losses = mask_time_indices.sum()
        if distributed.is_initialized():
            num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
            distributed.all_reduce(num_losses)
            num_losses = num_losses / distributed.get_world_size()
        contrastive_loss = contrastive_loss / num_losses
        diversity_loss = diversity_loss / num_losses
        self.num_losses_record = num_losses.detach()

        return contrastive_loss, diversity_loss

    def before_step(self, it):
        # Unscale the grads by the total number of losses encountered.
        pass

    def after_backward(self, it):
        if self.num_losses_record > 0:
            # Unscale the grads by the total number of losses encountered.
            for p in self.parameters():
                if p.grad is not None:
                    p.grad.data.div_(self.num_losses_record)
@register_model
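
The hunk above swaps the old one-line reduction for an in-place all_reduce followed by an explicit division: torch.distributed.all_reduce mutates its tensor argument and returns None when async_op is False, so its result cannot be divided directly. The globally averaged count is also stashed in self.num_losses_record so that after_backward can divide the accumulated gradients by it. Below is a minimal, single-process sketch of that pattern; it is not the repo's code, and ToyWrapper, the toy data, and the squared-error loss are illustrative stand-ins.

# Minimal sketch of the loss-count reduction and post-backward grad unscaling.
# ToyWrapper and the data below are illustrative only, not the repository's model.
import torch
import torch.nn as nn
import torch.distributed as distributed

class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)
        self.num_losses_record = torch.tensor(0.)

    def forward(self, x, mask):
        num_losses = mask.sum().float()
        if distributed.is_initialized():
            # all_reduce works in place and returns None when async_op=False,
            # which is why the old one-line form could not produce a usable value.
            distributed.all_reduce(num_losses)
            num_losses = num_losses / distributed.get_world_size()
        loss = (self.net(x)[mask] ** 2).sum() / num_losses
        self.num_losses_record = num_losses.detach()
        return loss

    def after_backward(self, it):
        # Unscale the grads by the total number of losses encountered.
        if self.num_losses_record > 0:
            for p in self.parameters():
                if p.grad is not None:
                    p.grad.data.div_(self.num_losses_record)

model = ToyWrapper()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(4, 8)
mask = torch.tensor([True, False, True, True])
loss = model(x, mask)
loss.backward()
model.after_backward(it=0)  # normally invoked by the trainer loop (see the next hunk)
opt.step()

Calling after_backward between loss.backward() and the optimizer step is exactly what the ExtensibleTrainer change below wires up automatically.
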


@@ -294,6 +294,11 @@ class ExtensibleTrainer(BaseModel):
                self.batch_size_optimizer.focus(net)

            for m in range(self.batch_factor):
                ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))

                # Call into post-backward hooks.
                for name, net in self.networks.items():
                    if hasattr(net.module, "after_backward"):
                        net.module.after_backward(it)

                for k, v in ns.items():
                    if k not in new_states.keys():
                        new_states[k] = [v]
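
This hunk wires the hook into the training loop: after each forward/backward micro-batch the trainer walks the registered networks and, whenever the wrapped module (net.module, i.e. the model inside its DDP-style wrapper) defines after_backward, calls it with the current iteration. A small self-contained sketch of that hasattr-based dispatch follows; _FakeDDP, WithHook, and call_after_backward_hooks are illustrative names, not part of the repository.

# Sketch of the optional post-backward hook dispatch, assuming each entry in
# `networks` is a DistributedDataParallel-style wrapper exposing `.module`.
import torch.nn as nn

class _FakeDDP:
    """Stand-in for torch.nn.parallel.DistributedDataParallel (illustrative only)."""
    def __init__(self, module: nn.Module):
        self.module = module

def call_after_backward_hooks(networks: dict, it: int) -> None:
    # Only wrapped modules that define after_backward(it) are invoked.
    for name, net in networks.items():
        if hasattr(net.module, "after_backward"):
            net.module.after_backward(it)

class WithHook(nn.Module):
    def after_backward(self, it):
        print(f"after_backward called at iteration {it}")

networks = {"with_hook": _FakeDDP(WithHook()), "plain": _FakeDDP(nn.Linear(2, 2))}
call_after_backward_hooks(networks, it=0)

Using hasattr keeps the hook opt-in, so networks that never scale their losses pay no cost and need no changes.
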