torchscale/examples/fairseq/tasks/__init__.py
2022-11-23 08:36:55 -08:00

35 lines
1.3 KiB
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import argparse
import importlib
import os
# register dataclass
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
path = os.path.join(tasks_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
task_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("tasks." + task_name)
# expose `task_parser` for sphinx
if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group("Task name")
# fmt: off
group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``')
# fmt: on
group_args = parser.add_argument_group("Additional command-line arguments")
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser