Fix Bert MoE
This commit is contained in:
parent
670113e446
commit
c397ebb013
|
@ -0,0 +1,8 @@
|
||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
# automatically import any Python files in the criterions/ directory
|
||||||
|
for file in sorted(os.listdir(os.path.dirname(__file__))):
|
||||||
|
if file.endswith(".py") and not file.startswith("_"):
|
||||||
|
file_name = file[: file.find(".py")]
|
||||||
|
importlib.import_module("criterions." + file_name)
|
|
@ -4,6 +4,7 @@
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fairseq import metrics, utils
|
from fairseq import metrics, utils
|
||||||
from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
|
from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
|
||||||
|
@ -50,7 +51,7 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reduce_metrics(logging_outputs) -> None:
|
def reduce_metrics(logging_outputs) -> None:
|
||||||
"""Aggregate logging outputs from data parallel training."""
|
"""Aggregate logging outputs from data parallel training."""
|
||||||
MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
|
MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
|
||||||
|
|
||||||
loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
|
loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
|
||||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
import criterions
|
||||||
from fairseq_cli.generate import cli_main
|
from fairseq_cli.generate import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
import criterions
|
||||||
from fairseq_cli.interactive import cli_main
|
from fairseq_cli.interactive import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
import criterions
|
||||||
from fairseq_cli.train import cli_main
|
from fairseq_cli.train import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user