Fix Bert MoE
This commit is contained in:
parent
670113e446
commit
c397ebb013
|
@ -0,0 +1,8 @@
|
|||
import importlib
|
||||
import os
|
||||
|
||||
# automatically import any Python files in the criterions/ directory
|
||||
for file in sorted(os.listdir(os.path.dirname(__file__))):
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
file_name = file[: file.find(".py")]
|
||||
importlib.import_module("criterions." + file_name)
|
|
@ -4,6 +4,7 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig
|
||||
|
@ -50,7 +51,7 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion):
|
|||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
|
||||
MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs)
|
||||
|
||||
loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs)
|
||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# flake8: noqa
|
||||
import models
|
||||
import tasks
|
||||
import criterions
|
||||
from fairseq_cli.generate import cli_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# flake8: noqa
|
||||
import models
|
||||
import tasks
|
||||
import criterions
|
||||
from fairseq_cli.interactive import cli_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# flake8: noqa
|
||||
import models
|
||||
import tasks
|
||||
import criterions
|
||||
from fairseq_cli.train import cli_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue
Block a user