just take the mean...
This commit is contained in:
parent 6130391a85
commit ee364f4eeb
@@ -638,27 +638,29 @@ class ContrastiveTrainingWrapper(nn.Module):
         logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
         target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
 
-        contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+        contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="mean")
         # 7. compute diversity loss: \mathbf{L}_d
         num_codevectors = self.quantizer.num_codevectors
-        diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+        diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
 
+        """
         num_losses = mask_time_indices.sum()
         if distributed.is_initialized():
             distributed.all_reduce(num_losses)
             num_losses = num_losses / distributed.get_world_size()
-        contrastive_loss = contrastive_loss
-        diversity_loss = diversity_loss
 
         self.num_losses_record = num_losses.detach()
+        """
 
         return contrastive_loss, diversity_loss
 
+    """
     def after_backward(self, it):
         if self.num_losses_record > 0:
             # Unscale the grads by the total number of losses encountered.
             for p in self.parameters():
                 if p.grad is not None:
                     p.grad.data.div_(self.num_losses_record)
+    """
 
 @register_model
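A note on why the reduction change can stand in for the removed num_losses bookkeeping: torch.nn.functional.cross_entropy uses ignore_index=-100 by default, and with reduction="mean" it averages only over targets that are not ignored. Since unmasked positions are written as -100 here, "mean" divides the summed loss by the number of masked time steps, which is exactly what dividing by num_losses did on a single process. A minimal sketch, with made-up shapes and counts for illustration:

import torch
import torch.nn.functional as F

# Hypothetical shapes: 40 time steps, each scored against 10 candidates
# (the true quantized target plus 9 negatives).
logits = torch.randn(40, 10)
target = torch.full((40,), -100)   # -100 = ignored, i.e. unmasked positions
target[:25] = 0                    # 25 masked positions predict class 0

loss_sum = F.cross_entropy(logits, target, reduction="sum")
loss_mean = F.cross_entropy(logits, target, reduction="mean")

# "mean" divides by the 25 non-ignored targets, not by all 40 positions.
num_losses = (target != -100).sum()
assert torch.allclose(loss_mean, loss_sum / num_losses)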
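One caveat worth flagging, since the commit also retires the distributed normalization and the after_backward gradient unscaling: the old scheme (reduction="sum", then dividing gradients by the all-reduced masked-position count) weighted every masked position equally across workers, while a per-rank "mean" combined with DDP's gradient averaging weights every rank equally, regardless of how many positions that rank masked. The two coincide only when all ranks mask the same number of positions. A toy sketch with made-up numbers, assuming standard DDP gradient averaging across 2 workers:

import torch

# Hypothetical per-rank loss sums and masked-position counts.
sums = torch.tensor([12.0, 2.0])    # per-rank cross_entropy(..., reduction="sum")
counts = torch.tensor([8.0, 2.0])   # per-rank mask_time_indices.sum()

# Old scheme: sum the losses, all-reduce the count, divide grads by it.
# Equivalent to one global mean over all masked positions.
old = sums.sum() / counts.sum()     # 14 / 10 = 1.4

# New scheme: per-rank mean, then DDP averages gradients across ranks.
new = (sums / counts).mean()        # (1.5 + 1.0) / 2 = 1.25

print(old.item(), new.item())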