Fix Bert MoE

shumingma 2023-03-07 21:11:05 -08:00
parent 670113e446
commit c397ebb013
5 changed files with 13 additions and 1 deletion

View File

@@ -0,0 +1,8 @@
+import importlib
+import os
+
+# automatically import any Python files in the criterions/ directory
+for file in sorted(os.listdir(os.path.dirname(__file__))):
+    if file.endswith(".py") and not file.startswith("_"):
+        file_name = file[: file.find(".py")]
+        importlib.import_module("criterions." + file_name)
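This new criterions/__init__.py works purely by import side effect: importing the package imports every .py module in the directory whose name doesn't start with an underscore, which lets each module's register_criterion decorator run and make the criterion visible to fairseq. A minimal sketch of a module this loop would pick up (the file name and class are hypothetical, not part of the commit):

# criterions/example_criterion.py -- hypothetical; the loop above imports it
# automatically because its name ends in ".py" and doesn't start with "_".
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion

@register_criterion("example_cross_entropy")
class ExampleCrossEntropyCriterion(CrossEntropyCriterion):
    # Registration happens at import time via the decorator; nothing else
    # is needed for --criterion example_cross_entropy to resolve.
    pass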

View File

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 import math
 import torch
+import torch.nn.functional as F
 from fairseq import metrics, utils
 from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
@@ -50,7 +51,7 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
     @staticmethod
     def reduce_metrics(logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
-        MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
+        MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
         loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
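The one-line change above is the substance of the fix: reduce_metrics called reduce_moe_metrics on MoECrossEntropyCriterion, a class this module neither defines nor imports, so aggregating logging outputs would fail with a NameError at the first logging step; it now calls the method on MaskedLMMoECrossEntropyCriterion itself. For context, the two sums computed afterwards typically feed a logging call like the following, mirroring fairseq's usual reduce_metrics pattern (a sketch; the actual continuation is outside the hunk):

        # Hypothetical continuation: log the per-token inner (masked-LM) loss
        # in base 2, weighted by token count, as fairseq criteria usually do.
        metrics.log_scalar(
            "inner_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
        )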

View File

@@ -4,6 +4,7 @@
 # flake8: noqa
 import models
 import tasks
+import criterions
 from fairseq_cli.generate import cli_main
 
 if __name__ == "__main__":

View File

@@ -4,6 +4,7 @@
 # flake8: noqa
 import models
 import tasks
+import criterions
 from fairseq_cli.interactive import cli_main
 
 if __name__ == "__main__":

View File

@@ -4,6 +4,7 @@
 # flake8: noqa
 import models
 import tasks
+import criterions
 from fairseq_cli.train import cli_main
 
 if __name__ == "__main__":
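All three entry scripts (generate, interactive, train) receive the same one-line addition: import criterions exists only for its side effect of triggering the auto-import in the new criterions/__init__.py, which is why it sits under the # flake8: noqa marker that silences unused-import warnings. Without it, the MoE criterion would never reach fairseq's registry and could not be selected on the command line. A sanity-check sketch, assuming fairseq exposes CRITERION_REGISTRY as recent versions do (the registered name is inferred from the class above, not confirmed by the diff):

# Hypothetical check that the side-effect import registered the criterion.
import criterions  # noqa: F401 -- imported only for its registration side effect
from fairseq.criterions import CRITERION_REGISTRY

assert "masked_lm_moe_cross_entropy" in CRITERION_REGISTRY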