diff --git a/examples/fairseq/criterions/__init__.py b/examples/fairseq/criterions/__init__.py index e69de29..9901f27 100644 --- a/examples/fairseq/criterions/__init__.py +++ b/examples/fairseq/criterions/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("criterions." + file_name) \ No newline at end of file diff --git a/examples/fairseq/criterions/masked_lm_moe.py b/examples/fairseq/criterions/masked_lm_moe.py index 88d4724..dd41cba 100644 --- a/examples/fairseq/criterions/masked_lm_moe.py +++ b/examples/fairseq/criterions/masked_lm_moe.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import math +import torch import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import MoECriterion, register_criterion, MoECriterionConfig @@ -50,7 +51,7 @@ class MaskedLMMoECrossEntropyCriterion(MoECriterion): @staticmethod def reduce_metrics(logging_outputs) -> None: """Aggregate logging outputs from data parallel training.""" - MoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs) + MaskedLMMoECrossEntropyCriterion.reduce_moe_metrics(logging_outputs) loss_sum = sum(log.get("inner_loss", 0) for log in logging_outputs) ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) diff --git a/examples/fairseq/generate.py b/examples/fairseq/generate.py index 37b8945..3202873 100644 --- a/examples/fairseq/generate.py +++ b/examples/fairseq/generate.py @@ -4,6 +4,7 @@ # flake8: noqa import models import tasks +import criterions from fairseq_cli.generate import cli_main if __name__ == "__main__": diff --git a/examples/fairseq/interactive.py b/examples/fairseq/interactive.py index dca22d3..821d617 100644 --- a/examples/fairseq/interactive.py +++ b/examples/fairseq/interactive.py @@ -4,6 +4,7 @@ # flake8: noqa import models import tasks +import criterions from fairseq_cli.interactive import cli_main if __name__ == "__main__": diff --git a/examples/fairseq/train.py b/examples/fairseq/train.py index 0b0404e..35202fd 100644 --- a/examples/fairseq/train.py +++ b/examples/fairseq/train.py @@ -4,6 +4,7 @@ # flake8: noqa import models import tasks +import criterions from fairseq_cli.train import cli_main if __name__ == "__main__":