Update Bert MoE

This commit is contained in:
shumingma 2023-03-07 21:21:48 -08:00
parent c397ebb013
commit 0a07df1e5b

View File

@@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
) )
lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1)) lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1) target = model.get_targets(sample, net_output)
if masked_tokens is not None: if masked_tokens is not None:
targets = targets[masked_tokens] target = target[masked_tokens]
nll_loss = F.nll_loss( nll_loss = F.nll_loss(
lprobs, lprobs,
target, target.view(-1),
ignore_index=self.padding_idx, ignore_index=self.padding_idx,
reduction="sum" if reduce else "none", reduction="sum" if reduce else "none",
) )