forked from mrq/DL-Art-School
Do grad reduction
commit 7213ad2b89
parent 7c82e18c6c
@@ -545,8 +545,8 @@ class ContrastiveTrainingWrapper(nn.Module):
         self.max_gumbel_temperature = max_gumbel_temperature
         self.min_gumbel_temperature = min_gumbel_temperature
         self.gumbel_temperature_decay = gumbel_temperature_decay
 
         self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim)
+        self.num_losses_record = []
 
         # make sure that project_hid & project_q are initialized like normal linear layers
         self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
@@ -645,15 +645,20 @@ class ContrastiveTrainingWrapper(nn.Module):
 
         num_losses = mask_time_indices.sum()
         if distributed.is_initialized():
-            num_losses = distributed.all_reduce(num_losses) / distributed.get_world_size()
+            distributed.all_reduce(num_losses)
+            num_losses = num_losses / distributed.get_world_size()
         contrastive_loss = contrastive_loss / num_losses
         diversity_loss = diversity_loss / num_losses
 
+        self.num_losses_record = num_losses.detach()
         return contrastive_loss, diversity_loss
 
-    def before_step(self, it):
-        # Unscale the grads by the total number of losses encountered.
-        pass
+    def after_backward(self, it):
+        if self.num_losses_record > 0:
+            # Unscale the grads by the total number of losses encountered.
+            for p in self.parameters():
+                if p.grad is not None:
+                    p.grad.data.div_(self.num_losses_record)
 
 
 @register_model
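The hunk above fixes the reduction itself: torch.distributed.all_reduce() modifies its tensor in place and, in its default synchronous form, returns None rather than the reduced tensor, so the old one-liner could not produce a usable value. The new code reduces in place, divides by the world size to get the mean loss count across ranks, records that count, and later unscales the accumulated gradients in after_backward(). Below is a minimal, self-contained sketch of the same pattern, not the repository's code: ToyWrapper, its single Linear layer, compute_losses(), and the demo tensors are illustrative stand-ins; only the all_reduce / get_world_size / after_backward logic mirrors the diff.

import torch
import torch.distributed as distributed
import torch.nn as nn

class ToyWrapper(nn.Module):
    """Illustrative stand-in (assumption, not the repo's ContrastiveTrainingWrapper)."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.num_losses_record = torch.tensor(0.)

    def compute_losses(self, loss_per_position, mask_time_indices):
        # Count the masked positions that contributed a loss term on this rank.
        num_losses = mask_time_indices.sum()
        if distributed.is_initialized():
            # all_reduce() works in place and returns None by default, so reduce
            # first, then divide by the world size to get the mean count per rank.
            distributed.all_reduce(num_losses)
            num_losses = num_losses / distributed.get_world_size()
        loss = loss_per_position.sum() / num_losses
        # Remember the normalizer so gradients can be unscaled after backward().
        self.num_losses_record = num_losses.detach()
        return loss

    def after_backward(self, it):
        # Unscale the accumulated grads by the recorded number of losses.
        if self.num_losses_record > 0:
            for p in self.parameters():
                if p.grad is not None:
                    p.grad.data.div_(self.num_losses_record)

if __name__ == "__main__":
    wrapper = ToyWrapper()
    feats = torch.randn(4, 10, 8)
    mask = torch.zeros(4, 10, dtype=torch.bool)
    mask[:, :5] = True
    per_position = (wrapper.proj(feats) ** 2).mean(dim=-1) * mask
    wrapper.compute_losses(per_position, mask).backward()
    wrapper.after_backward(it=0)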
@@ -294,6 +294,11 @@ class ExtensibleTrainer(BaseModel):
                 self.batch_size_optimizer.focus(net)
             for m in range(self.batch_factor):
                 ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
+                # Call into post-backward hooks.
+                for name, net in self.networks.items():
+                    if hasattr(net.module, "after_backward"):
+                        net.module.after_backward(it)
+
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
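The trainer-side hunk wires the hook up: after each micro-batch's do_forward_backward(), ExtensibleTrainer checks every registered network's underlying .module (networks may be DDP-wrapped) for an after_backward method and calls it, which is how the gradient unscaling above gets run. A minimal sketch of that dispatch pattern, with a hypothetical DummyNet standing in for a wrapped network (names are illustrative, not part of the repo):

class DummyNet:
    """Hypothetical stand-in for a DDP-wrapped network; only .module and the
    optional after_backward hook matter to the dispatch below."""
    def __init__(self):
        self.module = self  # DDP wrappers expose the real network as .module

    def after_backward(self, it):
        print(f"after_backward hook invoked at iteration {it}")

def call_post_backward_hooks(networks, it):
    # Mirrors the loop added to ExtensibleTrainer: call into post-backward
    # hooks on any network that defines one.
    for name, net in networks.items():
        if hasattr(net.module, "after_backward"):
            net.module.after_backward(it)

call_post_backward_hooks({"contrastive_wrapper": DummyNet()}, it=0)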