Bert MoE
parent 0cb9695501
commit 20c1e6c611
@@ -70,7 +70,7 @@ You can quickly get started with our processed vocabulary files: [sentencepiece.
spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
```
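
The conversion assumes the usual `spm_export_vocab` layout of one `piece<TAB>score` row per symbol, with `<unk>`, `<s>`, and `</s>` in the first three rows, which `tail -n +4` drops. As a quick sanity check (a sketch only; the file name and expectations come from the command above, not from any fairseq API), the resulting `dict.txt` should hold one space-separated `token score` pair per line:

```python
# Sanity-check sketch: dict.txt should contain "<token> <score>" pairs,
# with the SentencePiece control symbols already removed by `tail -n +4`.
with open("dict.txt", encoding="utf-8") as f:
    for _, line in zip(range(5), f):
        token, score = line.rstrip("\n").rsplit(" ", 1)
        print(token, score)
```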
### Training Command

### Dense Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
@@ -117,6 +117,58 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH
--deepnorm
```

### Sparse (MoE) Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
--task pretraining \
--tokens-per-sample 512 \
--mask-prob 0.15 \
--span-length 3.0 \
--leave-unmasked-prob 0.0 \
--random-token-prob 0.0 \
--criterion masked_lm \
--arch mlm_base \
--share-encoder-input-output-embed \
--required-batch-size-multiple 8 \
--spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
--dict-file ${PATH_TO_DATA}/dict.txt \
--optimizer adam \
--adam-betas '(0.9,0.98)' \
--adam-eps 1e-6 \
--clip-norm 2.0 \
--lr-scheduler polynomial_decay \
--lr 0.0005 \
--warmup-updates 10000 \
--total-num-update 125000 \
--max-update 125000 \
--max-sentences 32 \
--update-freq 1 \
--log-format simple \
--log-interval 100 \
--disable-validation \
--save-interval-updates 5000 \
--no-epoch-checkpoints \
--fp16 \
--fp16-init-scale 4 \
--fp16-scale-window 256 \
--min-loss-scale 0.0001 \
--seed 1 \
--save-dir ${PATH_TO_CKPT} \
--ddp-backend=no_c10d \
--distributed-no-spawn \
--reset-dataloader \
--batch-read-ahead 10000 \
--rel-pos-buckets 32 \
--max-rel-pos 128 \
--deepnorm \
--moe-expert-count 64 --moe-freq 2 \
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
--moe-eval-capacity-token-fraction -1.0 \
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
--use-xmoe
```
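
With `--nproc_per_node=8 --nnodes=8` the job spans 64 workers, so `--max-sentences 32` and `--update-freq 1` give an effective batch of 64 × 32 = 2048 sequences per update, and `--moe-freq 2` places a mixture-of-experts FFN in every other block, each routed over `--moe-expert-count 64` experts. For intuition only, below is a minimal, self-contained sketch of the kind of top-2 routing these flags describe; it is not the torchscale/fairseq implementation, and the `top2_gate` helper and toy dimensions are invented for illustration.

```python
# A rough sketch of top-2 expert routing in the spirit of the flags above
# (NOT the torchscale/fairseq implementation): fp32 gating, a randomly
# sampled second expert, normalized gate probabilities, and a small
# load-balancing ("gate") loss weighted by 0.01.
import torch
import torch.nn.functional as F


def top2_gate(x, w_gate, gate_loss_wt=0.01):
    # x: (num_tokens, d_model), w_gate: (d_model, num_experts)
    logits = x.float() @ w_gate.float()    # gate in fp32 (--moe-gating-use-fp32)
    probs = F.softmax(logits, dim=-1)      # (num_tokens, num_experts)

    top1_prob, top1_idx = probs.max(dim=-1)
    # "random" second-expert policy: sample the 2nd expert from the remaining
    # probability mass instead of always taking the runner-up.
    rest = probs.scatter(-1, top1_idx.unsqueeze(-1), 0.0)
    top2_idx = torch.multinomial(rest, num_samples=1).squeeze(-1)
    top2_prob = rest.gather(-1, top2_idx.unsqueeze(-1)).squeeze(-1)

    # Normalize the two gate values before any capacity-based token dropping
    # (--moe-normalize-gate-prob-before-dropping).
    denom = top1_prob + top2_prob
    gate1, gate2 = top1_prob / denom, top2_prob / denom

    # Switch/GShard-style load-balancing loss: mean gate probability per
    # expert times the fraction of tokens whose top-1 choice is that expert.
    num_experts = probs.size(-1)
    mean_prob = probs.mean(dim=0)
    frac_top1 = F.one_hot(top1_idx, num_experts).float().mean(dim=0)
    gate_loss = gate_loss_wt * num_experts * (mean_prob * frac_top1).sum()

    return (top1_idx, gate1), (top2_idx, gate2), gate_loss


# Toy usage: route 8 token vectors over 64 experts (dimensions are made up).
tokens = torch.randn(8, 768)
w_gate = torch.randn(768, 64)
(top1, g1), (top2, g2), aux_loss = top2_gate(tokens, w_gate)
```

Note that `--criterion` appears twice in the command; the later `moe_cross_entropy` value is the one argparse keeps, and it is the criterion that folds the gate loss (combined across MoE layers per `--moe-gate-loss-combine-method sum`) into the training loss.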

## Example: GPT Pretraining

### Data Format

@@ -349,6 +349,18 @@ class BertModel(BaseFairseqModel):
            if prefix + "classification_heads." + k not in state_dict:
                logger.info("Overwriting " + prefix + "classification_heads." + k)
                state_dict[prefix + "classification_heads." + k] = v

    def get_normalized_probs_scriptable(
        self,
        net_output,
        log_probs,
        sample=None,
    ):
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1)
        else:
            return utils.softmax(logits, dim=-1)

    def forward(
        self,