diff --git a/examples/fairseq/criterions/masked_lm_moe.py b/examples/fairseq/criterions/masked_lm_moe.py index dd41cba..8e25442 100644 --- a/examples/fairseq/criterions/masked_lm_moe.py +++ b/examples/fairseq/criterions/masked_lm_moe.py @@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion): ) lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = lprobs.view(-1, lprobs.size(-1)) - target = model.get_targets(sample, net_output).view(-1) + target = model.get_targets(sample, net_output) if masked_tokens is not None: - targets = targets[masked_tokens] + target = target[masked_tokens] nll_loss = F.nll_loss( lprobs, - target, + target.view(-1), ignore_index=self.padding_idx, reduction="sum" if reduce else "none", )