diff --git a/examples/fairseq/README.md b/examples/fairseq/README.md
index c663abd..b87c4bc 100644
--- a/examples/fairseq/README.md
+++ b/examples/fairseq/README.md
@@ -70,7 +70,7 @@ You can quickly get started with our processed vocabulary files: [sentencepiece.
 spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
 ```
 
-### Training Command
+### Dense Model
 ```bash
 cd examples/fairseq/
 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
@@ -117,6 +117,58 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH
     --deepnorm
 ```
 
+### Sparse (MoE) Model
+```bash
+cd examples/fairseq/
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
+    --task pretraining \
+    --tokens-per-sample 512 \
+    --mask-prob 0.15 \
+    --span-length 3.0 \
+    --leave-unmasked-prob 0.0 \
+    --random-token-prob 0.0 \
+    --criterion masked_lm \
+    --arch mlm_base \
+    --share-encoder-input-output-embed \
+    --required-batch-size-multiple 8 \
+    --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
+    --dict-file ${PATH_TO_DATA}/dict.txt \
+    --optimizer adam \
+    --adam-betas '(0.9,0.98)' \
+    --adam-eps 1e-6 \
+    --clip-norm 2.0 \
+    --lr-scheduler polynomial_decay \
+    --lr 0.0005 \
+    --warmup-updates 10000 \
+    --total-num-update 125000 \
+    --max-update 125000 \
+    --max-sentences 32 \
+    --update-freq 1 \
+    --log-format simple \
+    --log-interval 100 \
+    --disable-validation \
+    --save-interval-updates 5000 \
+    --no-epoch-checkpoints \
+    --fp16 \
+    --fp16-init-scale 4 \
+    --fp16-scale-window 256 \
+    --min-loss-scale 0.0001 \
+    --seed 1 \
+    --save-dir ${PATH_TO_CKPT} \
+    --ddp-backend=no_c10d \
+    --distributed-no-spawn \
+    --reset-dataloader \
+    --batch-read-ahead 10000 \
+    --rel-pos-buckets 32 \
+    --max-rel-pos 128 \
+    --deepnorm \
+    --moe-expert-count 64 --moe-freq 2 \
+    --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
+    --moe-eval-capacity-token-fraction -1.0 \
+    --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
+    --use-xmoe
+```
+
 ## Example: GPT Pretraining
 
 ### Data Format
diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py
index 8327484..f804973 100644
--- a/examples/fairseq/models/bert.py
+++ b/examples/fairseq/models/bert.py
@@ -349,6 +349,18 @@ class BertModel(BaseFairseqModel):
                 if prefix + "classification_heads." + k not in state_dict:
                     logger.info("Overwriting " + prefix + "classification_heads." + k)
                     state_dict[prefix + "classification_heads." + k] = v
+
+    def get_normalized_probs_scriptable(
+        self,
+        net_output,
+        log_probs,
+        sample=None,
+    ):
+        logits = net_output[0]
+        if log_probs:
+            return utils.log_softmax(logits, dim=-1)
+        else:
+            return utils.softmax(logits, dim=-1)
 
     def forward(
         self,
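
The bert.py hunk above adds a `get_normalized_probs_scriptable` override without a docstring. As a rough illustration of what it computes, and not part of the patch itself, the sketch below reproduces the same logic in plain PyTorch (using `torch.nn.functional` rather than fairseq's `utils` helpers), assuming `net_output[0]` holds the vocabulary logits returned by `BertModel.forward`; the function name, tensor shapes, and vocabulary size are invented for the example.

```python
# Minimal standalone sketch (plain PyTorch, not the fairseq utils helpers) of the
# behaviour the override adds: take the logits from the first element of
# net_output and normalize them over the vocabulary dimension.
import torch
import torch.nn.functional as F


def normalized_probs(net_output, log_probs: bool):
    # net_output is assumed to be a (logits, extra) tuple, as in BertModel.forward.
    logits = net_output[0]  # illustrative shape: (batch, seq_len, vocab_size)
    if log_probs:
        return F.log_softmax(logits, dim=-1)
    return F.softmax(logits, dim=-1)


# Hypothetical tensors standing in for a forward() output.
logits = torch.randn(2, 8, 32)  # batch=2, seq_len=8, vocab=32 (illustrative sizes)
lprobs = normalized_probs((logits, {}), log_probs=True)

# Log-probabilities should exponentiate back to a distribution over the vocabulary.
assert torch.allclose(lprobs.exp().sum(dim=-1), torch.ones(2, 8), atol=1e-5)
```

In fairseq, criteria typically obtain (log-)probabilities via `model.get_normalized_probs(...)`, which delegates to the scriptable variant, so the override makes that call work for `BertModel`, whose `forward` returns a `(logits, extra)` tuple.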