Update Bert MoE

This commit is contained in:
shumingma 2023-03-07 21:21:48 -08:00
parent c397ebb013
commit 0a07df1e5b

View File

@@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
  )
  lprobs = model.get_normalized_probs(net_output, log_probs=True)
  lprobs = lprobs.view(-1, lprobs.size(-1))
- target = model.get_targets(sample, net_output).view(-1)
+ target = model.get_targets(sample, net_output)
  if masked_tokens is not None:
- targets = targets[masked_tokens]
+ target = target[masked_tokens]
  nll_loss = F.nll_loss(
  lprobs,
- target,
+ target.view(-1),
  ignore_index=self.padding_idx,
  reduction="sum" if reduce else "none",
  )