fix bert moe

This commit is contained in:
buaahsh 2023-03-05 07:43:58 +00:00
parent 32cb51ae38
commit bc140c65bb
3 changed files with 9 additions and 0 deletions

View File

@ -130,6 +130,9 @@ class BertConfig(FairseqDataclass):
tpu: bool = II("common.tpu")
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
max_rel_pos: int = field(default=0, metadata={"help": ""})
use_xmoe: Optional[bool] = field(
default=False,
)
moe_freq: int = field(
default=0,
metadata={"help": "Frequency at which we insert MoE Transformer layers"},

View File

@ -166,6 +166,9 @@ class MLMLoader(BaseBatchGen):
mlm_target_max_length = max([len(x[1]) for x in batch])
s2s_source_max_length = max([len(x[2]) for x in batch])
s2s_target_max_length = max([len(x[3]) for x in batch])
if self.args.pad_to_max_length:
mlm_source_max_length = self.args.tokens_per_sample
mlm_target_max_length = self.args.tokens_per_sample
mlm_source_ids = np.full(
shape=(batch_size, mlm_source_max_length),

View File

@ -117,6 +117,9 @@ class PretrainingConfig(FairseqDataclass):
default="",
metadata={"help": ""},
)
pad_to_max_length: bool = field(
default=False,
)
@register_task("pretraining", dataclass=PretrainingConfig)