Update Bert MoE
This commit is contained in:
parent
c397ebb013
commit
0a07df1e5b
|
@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
|
|||
)
|
||||
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
||||
lprobs = lprobs.view(-1, lprobs.size(-1))
|
||||
target = model.get_targets(sample, net_output).view(-1)
|
||||
target = model.get_targets(sample, net_output)
|
||||
|
||||
if masked_tokens is not None:
|
||||
targets = targets[masked_tokens]
|
||||
target = target[masked_tokens]
|
||||
|
||||
nll_loss = F.nll_loss(
|
||||
lprobs,
|
||||
target,
|
||||
target.view(-1),
|
||||
ignore_index=self.padding_idx,
|
||||
reduction="sum" if reduce else "none",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user