Do grad reduction
This commit is contained in:
parent
7c82e18c6c
commit
7213ad2b89
|
@ -545,8 +545,8 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
self.max_gumbel_temperature = max_gumbel_temperature
|
||||
self.min_gumbel_temperature = min_gumbel_temperature
|
||||
self.gumbel_temperature_decay = gumbel_temperature_decay
|
||||
|
||||
self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
|
||||
self.num_losses_record = []
|
||||
|
||||
# make sure that project_hid & project_q are initialized like normal linear layers
|
||||
self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
|
||||
|
@ -645,15 +645,20 @@ class ContrastiveTrainingWrapper(nn.Module):
|
|||
|
||||
num_losses = mask_time_indices.sum()
|
||||
if distributed.is_initialized():
|
||||
num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
|
||||
distributed.all_reduce(num_losses)
|
||||
num_losses = num_losses / distributed.get_world_size()
|
||||
contrastive_loss = contrastive_loss / num_losses
|
||||
diversity_loss = diversity_loss / num_losses
|
||||
|
||||
self.num_losses_record = num_losses.detach()
|
||||
return contrastive_loss, diversity_loss
|
||||
|
||||
def before_step(self, it):
|
||||
# Unscale the grads by the total number of losses encountered.
|
||||
pass
|
||||
def after_backward(self, it):
|
||||
if self.num_losses_record > 0:
|
||||
# Unscale the grads by the total number of losses encountered.
|
||||
for p in self.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad.data.div_(self.num_losses_record)
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -294,6 +294,11 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.batch_size_optimizer.focus(net)
|
||||
for m in range(self.batch_factor):
|
||||
ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
|
||||
# Call into post-backward hooks.
|
||||
for name, net in self.networks.items():
|
||||
if hasattr(net.module, "after_backward"):
|
||||
net.module.after_backward(it)
|
||||
|
||||
for k, v in ns.items():
|
||||
if k not in new_states.keys():
|
||||
new_states[k] = [v]
|
||||
|
|
Loading…
Reference in New Issue
Block a user