diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py
index f804973..c7c652d 100644
--- a/examples/fairseq/models/bert.py
+++ b/examples/fairseq/models/bert.py
@@ -130,6 +130,9 @@ class BertConfig(FairseqDataclass):
     tpu: bool = II("common.tpu")
     rel_pos_buckets: int = field(default=0, metadata={"help": ""})
     max_rel_pos: int = field(default=0, metadata={"help": ""})
+    use_xmoe: Optional[bool] = field(
+        default=False,
+    )
     moe_freq: int = field(
         default=0,
         metadata={"help": "Frequency at which we insert MoE Transformer layers"},
diff --git a/examples/fairseq/tasks/data/mlm_loader.py b/examples/fairseq/tasks/data/mlm_loader.py
index eb9cd72..510f654 100644
--- a/examples/fairseq/tasks/data/mlm_loader.py
+++ b/examples/fairseq/tasks/data/mlm_loader.py
@@ -166,6 +166,9 @@ class MLMLoader(BaseBatchGen):
         mlm_target_max_length = max([len(x[1]) for x in batch])
         s2s_source_max_length = max([len(x[2]) for x in batch])
         s2s_target_max_length = max([len(x[3]) for x in batch])
+        if self.args.pad_to_max_length:
+            mlm_source_max_length = self.args.tokens_per_sample
+            mlm_target_max_length = self.args.tokens_per_sample
 
         mlm_source_ids = np.full(
             shape=(batch_size, mlm_source_max_length),
diff --git a/examples/fairseq/tasks/pretraining.py b/examples/fairseq/tasks/pretraining.py
index 2d32127..2022907 100644
--- a/examples/fairseq/tasks/pretraining.py
+++ b/examples/fairseq/tasks/pretraining.py
@@ -117,6 +117,9 @@ class PretrainingConfig(FairseqDataclass):
         default="",
         metadata={"help": ""},
     )
+    pad_to_max_length: bool = field(
+        default=False,
+    )
 
 
 @register_task("pretraining", dataclass=PretrainingConfig)
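
For context, the behavioral change here is in mlm_loader.py: with the new pad_to_max_length task option set, MLM batches pad out to a fixed args.tokens_per_sample rather than to the longest sequence in the batch, so every batch has a static shape. Below is a minimal standalone sketch of that logic; the Args dataclass, the mlm_source_shape helper, and pad_idx are hypothetical stand-ins for illustration (in the real loader the fill value comes from the dictionary's pad index), while pad_to_max_length and tokens_per_sample are the fields added by this diff.

import numpy as np
from dataclasses import dataclass

@dataclass
class Args:
    # Hypothetical stand-in for the fairseq args namespace; only the
    # two fields this diff touches are modeled.
    pad_to_max_length: bool = False
    tokens_per_sample: int = 128

def mlm_source_shape(batch, args, pad_idx=1):
    """Mirror the padding logic added to MLMLoader: pad to the longest
    sequence in the batch by default, or to a fixed tokens_per_sample
    when pad_to_max_length is set."""
    mlm_source_max_length = max(len(x[0]) for x in batch)
    if args.pad_to_max_length:
        # Fixed-size batches: tensor shapes stay constant across steps.
        mlm_source_max_length = args.tokens_per_sample
    return np.full(
        shape=(len(batch), mlm_source_max_length),
        dtype=np.int32,
        fill_value=pad_idx,  # stand-in for dictionary.pad()
    )

# A toy "batch" of two samples whose MLM source ids (x[0]) have
# lengths 3 and 5.
batch = [([1, 2, 3],), ([1, 2, 3, 4, 5],)]
print(mlm_source_shape(batch, Args()).shape)                         # (2, 5)
print(mlm_source_shape(batch, Args(pad_to_max_length=True)).shape)   # (2, 128)

Static shapes are mainly useful on compilers/accelerators that recompile on shape changes (note the existing tpu flag in BertConfig); the cost is extra padding tokens per batch. The use_xmoe flag added to BertConfig is a plain config switch here and has no effect within this diff.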