Fix Bert MoE

This commit is contained in:
shumingma 2023-03-07 21:11:05 -08:00
parent 670113e446
commit c397ebb013
5 changed files with 13 additions and 1 deletions

View File

@ -0,0 +1,8 @@
import importlib
import os
# Automatically import every public Python module in the criterions/ directory
# so that each module's import-time side effects run (presumably criterion
# registration with fairseq -- confirm against the modules themselves).
# Files whose names start with "_" (including this __init__.py) are skipped,
# and sorted() keeps the import order deterministic.
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip only the trailing ".py" extension. str.find() would cut at the
        # FIRST ".py" occurrence and produce the wrong module name for any
        # file whose name contains ".py" before the extension.
        file_name = file[: -len(".py")]
        importlib.import_module("criterions." + file_name)

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
import torch
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import metrics, utils from fairseq import metrics, utils
from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
@ -50,7 +51,7 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
@staticmethod @staticmethod
def reduce_metrics(logging_outputs) -> None: def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs) MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs) loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)

View File

@ -4,6 +4,7 @@
# flake8: noqa # flake8: noqa
import models import models
import tasks import tasks
import criterions
from fairseq_cli.generate import cli_main from fairseq_cli.generate import cli_main
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4,6 +4,7 @@
# flake8: noqa # flake8: noqa
import models import models
import tasks import tasks
import criterions
from fairseq_cli.interactive import cli_main from fairseq_cli.interactive import cli_main
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -4,6 +4,7 @@
# flake8: noqa # flake8: noqa
import models import models
import tasks import tasks
import criterions
from fairseq_cli.train import cli_main from fairseq_cli.train import cli_main
if __name__ == "__main__": if __name__ == "__main__":