Update Bert MoE
This commit is contained in:
parent
c397ebb013
commit
0a07df1e5b
|
@ -29,14 +29,14 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
|
||||||
)
|
)
|
||||||
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
||||||
lprobs = lprobs.view(-1, lprobs.size(-1))
|
lprobs = lprobs.view(-1, lprobs.size(-1))
|
||||||
target = model.get_targets(sample, net_output).view(-1)
|
target = model.get_targets(sample, net_output)
|
||||||
|
|
||||||
if masked_tokens is not None:
|
if masked_tokens is not None:
|
||||||
targets = targets[masked_tokens]
|
target = target[masked_tokens]
|
||||||
|
|
||||||
nll_loss = F.nll_loss(
|
nll_loss = F.nll_loss(
|
||||||
lprobs,
|
lprobs,
|
||||||
target,
|
target.view(-1),
|
||||||
ignore_index=self.padding_idx,
|
ignore_index=self.padding_idx,
|
||||||
reduction="sum" if reduce else "none",
|
reduction="sum" if reduce else "none",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user