From 0a07df1e5bf1d3e6d640d8e306c900575e1c0fb3 Mon Sep 17 00:00:00 2001 From: shumingma Date: Tue, 7 Mar 2023 21:21:48 -0800 Subject: [PATCH] Update Bert MoE --- examples/fairseq/criterions/masked_lm_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/fairseq/criterions/masked_lm_moe.py b/examples/fairseq/criterions/masked_lm_moe.py index dd41cba..8e25442 100644 --- a/examples/fairseq/criterions/masked_lm_moe.py +++ b/examples/fairseq/criterions/masked_lm_moe.py @@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion): ) lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = lprobs.view(-1, lprobs.size(-1)) - target = model.get_targets(sample, net_output).view(-1) + target = model.get_targets(sample, net_output) if masked_tokens is not None: - targets = targets[masked_tokens] + target = target[masked_tokens] nll_loss = F.nll_loss( lprobs, - target, + target.view(-1), ignore_index=self.padding_idx, reduction="sum" if reduce else "none", )