commit
a491af1113
|
@ -10,7 +10,7 @@ pip install -e .
|
||||||
pip install git+https://github.com/shumingma/fairseq.git@moe
|
pip install git+https://github.com/shumingma/fairseq.git@moe
|
||||||
pip install git+https://github.com/shumingma/infinibatch.git
|
pip install git+https://github.com/shumingma/infinibatch.git
|
||||||
pip install iopath
|
pip install iopath
|
||||||
pip install --upgrade numpy
|
pip install numpy==1.23.0
|
||||||
```
|
```
|
||||||
|
|
||||||
## Example: BERT Pretraining
|
## Example: BERT Pretraining
|
||||||
|
@ -166,7 +166,7 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH
|
||||||
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
|
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
|
||||||
--moe-eval-capacity-token-fraction -1.0 \
|
--moe-eval-capacity-token-fraction -1.0 \
|
||||||
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
|
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
|
||||||
--use-xmoe
|
--use-xmoe --pad-to-max-length
|
||||||
```
|
```
|
||||||
|
|
||||||
## Example: GPT Pretraining
|
## Example: GPT Pretraining
|
||||||
|
|
|
@ -130,6 +130,9 @@ class BertConfig(FairseqDataclass):
|
||||||
tpu: bool = II("common.tpu")
|
tpu: bool = II("common.tpu")
|
||||||
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
|
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
|
||||||
max_rel_pos: int = field(default=0, metadata={"help": ""})
|
max_rel_pos: int = field(default=0, metadata={"help": ""})
|
||||||
|
use_xmoe: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
moe_freq: int = field(
|
moe_freq: int = field(
|
||||||
default=0,
|
default=0,
|
||||||
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
|
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
|
||||||
|
|
|
@ -166,6 +166,9 @@ class MLMLoader(BaseBatchGen):
|
||||||
mlm_target_max_length = max([len(x[1]) for x in batch])
|
mlm_target_max_length = max([len(x[1]) for x in batch])
|
||||||
s2s_source_max_length = max([len(x[2]) for x in batch])
|
s2s_source_max_length = max([len(x[2]) for x in batch])
|
||||||
s2s_target_max_length = max([len(x[3]) for x in batch])
|
s2s_target_max_length = max([len(x[3]) for x in batch])
|
||||||
|
if self.args.pad_to_max_length:
|
||||||
|
mlm_source_max_length = self.args.tokens_per_sample
|
||||||
|
mlm_target_max_length = self.args.tokens_per_sample
|
||||||
|
|
||||||
mlm_source_ids = np.full(
|
mlm_source_ids = np.full(
|
||||||
shape=(batch_size, mlm_source_max_length),
|
shape=(batch_size, mlm_source_max_length),
|
||||||
|
|
|
@ -117,6 +117,9 @@ class PretrainingConfig(FairseqDataclass):
|
||||||
default="",
|
default="",
|
||||||
metadata={"help": ""},
|
metadata={"help": ""},
|
||||||
)
|
)
|
||||||
|
pad_to_max_length: bool = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_task("pretraining", dataclass=PretrainingConfig)
|
@register_task("pretraining", dataclass=PretrainingConfig)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user