forked from mrq/DL-Art-School
just take the mean...
This commit is contained in:
parent 6130391a85
commit ee364f4eeb
@@ -638,27 +638,29 @@ class ContrastiveTrainingWrapper(nn.Module):
         logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
         target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
 
-        contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+        contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="mean")
         # 7. compute diversity loss: \mathbf{L}_d
         num_codevectors = self.quantizer.num_codevectors
-        diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+        diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
 
         """
         num_losses = mask_time_indices.sum()
         if distributed.is_initialized():
             distributed.all_reduce(num_losses)
             num_losses = num_losses / distributed.get_world_size()
         contrastive_loss = contrastive_loss
         diversity_loss = diversity_loss
 
         self.num_losses_record = num_losses.detach()
         """
 
         return contrastive_loss, diversity_loss
 
     """
     def after_backward(self, it):
         if self.num_losses_record > 0:
             # Unscale the grads by the total number of losses encountered.
             for p in self.parameters():
                 if p.grad is not None:
                     p.grad.data.div_(self.num_losses_record)
     """
 
 
 @register_model
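For context on the change above: with reduction="mean", nn.functional.cross_entropy already divides the summed loss by the number of targets that are not ignore_index (-100), which is the same per-loss normalization the commented-out num_losses / after_backward machinery used to apply to the gradients by hand. A minimal standalone sketch of that equivalence (not part of the commit; shapes and values are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(12, 5)        # 12 flattened positions, 5 candidate classes
target = torch.randint(0, 5, (12,))
target[::3] = -100                 # positions outside the mask are ignored

loss_sum = F.cross_entropy(logits, target, reduction="sum")
loss_mean = F.cross_entropy(logits, target, reduction="mean")
num_losses = (target != -100).sum()

# mean reduction == sum reduction divided by the count of non-ignored targets
assert torch.allclose(loss_mean, loss_sum / num_losses)

Under DDP the two are only approximately equivalent: the old path normalized by a globally all-reduced count, while mean reduction averages per rank and relies on the usual gradient averaging, which matches exactly only when every rank sees the same number of masked positions.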
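The diversity term moves in the same direction: dropping the multiplication by mask_time_indices.sum() keeps it a bounded per-batch quantity instead of one that grows with the number of masked frames, so it stays on a comparable scale to the mean-reduced contrastive loss. Writing num_codevectors as C and codevector_perplexity as \bar{p} (shorthand for the identifiers in the code, nothing new), the change is:

L_d^{old} = \frac{C - \bar{p}}{C} \cdot \sum_t m_t
\qquad\longrightarrow\qquad
L_d^{new} = \frac{C - \bar{p}}{C}

where m_t is mask_time_indices, so \sum_t m_t is the number of masked time steps.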