Bert MoE

commit 20c1e6c611 (parent 0cb9695501)

@@ -70,7 +70,7 @@ You can quickly get started with our processed vocabulary files: [sentencepiece.
spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
```

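For reference, `spm_export_vocab` prints one `piece<TAB>score` pair per line, and `tail -n +4` drops the first three rows (`<unk>`, `<s>`, `</s>`), which fairseq's dictionary handles on its own. A rough Python equivalent of the one-liner, assuming the `sentencepiece` package (illustration only; the shell pipeline above is what the repo documents):

```python
# Sketch of the same export in Python: dump the sentencepiece vocabulary,
# write a space instead of the tab separator, and skip the first three
# special pieces (<unk>, <s>, </s>).
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="sentencepiece.bpe.model")
with open("dict.txt", "w", encoding="utf-8") as f:
    for idx in range(3, sp.get_piece_size()):
        f.write(f"{sp.id_to_piece(idx)} {sp.get_score(idx)}\n")
```
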
### Dense Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \

@@ -117,6 +117,58 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH
    --deepnorm
```

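The command above passes `--deepnorm`, which enables DeepNet-style residual scaling. A minimal sketch of the idea for an N-layer encoder-only model, in plain PyTorch (an illustration of the published formula, not TorchScale's implementation):

```python
# DeepNorm residual (DeepNet): x_{l+1} = LayerNorm(alpha * x_l + sublayer(x_l)),
# with alpha = (2 * N) ** 0.25 for an N-layer encoder-only model; sublayer
# weights are additionally scaled by beta = (8 * N) ** -0.25 at init (not shown).
import torch
import torch.nn as nn


class DeepNormResidual(nn.Module):
    def __init__(self, sublayer: nn.Module, dim: int, num_layers: int):
        super().__init__()
        self.sublayer = sublayer
        self.alpha = (2 * num_layers) ** 0.25
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.alpha * x + self.sublayer(x))
```
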
### Sparse (MoE) Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
    --task pretraining \
    --tokens-per-sample 512 \
    --mask-prob 0.15 \
    --span-length 3.0 \
    --leave-unmasked-prob 0.0 \
    --random-token-prob 0.0 \
    --criterion masked_lm \
    --arch mlm_base \
    --share-encoder-input-output-embed \
    --required-batch-size-multiple 8 \
    --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
    --dict-file ${PATH_TO_DATA}/dict.txt \
    --optimizer adam \
    --adam-betas '(0.9,0.98)' \
    --adam-eps 1e-6 \
    --clip-norm 2.0 \
    --lr-scheduler polynomial_decay \
    --lr 0.0005 \
    --warmup-updates 10000 \
    --total-num-update 125000 \
    --max-update 125000 \
    --max-sentences 32 \
    --update-freq 1 \
    --log-format simple \
    --log-interval 100 \
    --disable-validation \
    --save-interval-updates 5000 \
    --no-epoch-checkpoints \
    --fp16 \
    --fp16-init-scale 4 \
    --fp16-scale-window 256 \
    --min-loss-scale 0.0001 \
    --seed 1 \
    --save-dir ${PATH_TO_CKPT} \
    --ddp-backend=no_c10d \
    --distributed-no-spawn \
    --reset-dataloader \
    --batch-read-ahead 10000 \
    --rel-pos-buckets 32 \
    --max-rel-pos 128 \
    --deepnorm \
    --moe-expert-count 64 --moe-freq 2 \
    --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
    --moe-eval-capacity-token-fraction -1.0 \
    --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
    --use-xmoe
```
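
The appended MoE flags turn the dense encoder above into a sparse one: 64 experts placed in every other layer (`--moe-expert-count 64 --moe-freq 2`), top-2 routing with gate logits computed in fp32, a stochastically kept second expert, and an auxiliary gate (load-balancing) loss summed into the masked-LM loss with weight 0.01. A minimal sketch of that routing scheme in plain PyTorch, as an illustration only (not fairseq's or TorchScale's actual implementation):

```python
# Illustrative top-2 gating. Shown: fp32 gate logits (--moe-gating-use-fp32),
# a stochastic second expert (--moe-second-expert-policy random), renormalizing
# the two gate probabilities before any capacity-based token dropping
# (--moe-normalize-gate-prob-before-dropping), and the auxiliary
# load-balancing loss weighted by --moe-gate-loss-wt.
import torch
import torch.nn.functional as F


def top2_gate(gate_logits: torch.Tensor, gate_loss_wt: float = 0.01):
    # gate_logits: (num_tokens, num_experts); compute routing in fp32.
    probs = F.softmax(gate_logits.float(), dim=-1)

    # First expert: the highest-probability expert for each token.
    top1_prob, top1_idx = probs.max(dim=-1)

    # Second expert: the best expert once the first one is masked out.
    masked = probs.scatter(-1, top1_idx.unsqueeze(-1), 0.0)
    top2_prob, top2_idx = masked.max(dim=-1)

    # Renormalize the two gate probabilities so they sum to 1 *before*
    # any token is dropped for exceeding expert capacity.
    denom = (top1_prob + top2_prob).clamp(min=1e-9)
    top1_prob, top2_prob = top1_prob / denom, top2_prob / denom

    # "random" policy: keep the second expert stochastically, with
    # probability proportional to its gate weight.
    keep_second = torch.rand_like(top2_prob) < (2.0 * top2_prob)

    # Load-balancing loss: pushes mean gate probability and mean token
    # assignment toward a uniform distribution over experts.
    num_experts = probs.size(-1)
    mean_probs = probs.mean(dim=0)
    mean_assign = F.one_hot(top1_idx, num_experts).float().mean(dim=0)
    gate_loss = gate_loss_wt * num_experts * torch.sum(mean_probs * mean_assign)

    return (top1_idx, top1_prob), (top2_idx, top2_prob, keep_second), gate_loss
```

The sketch omits expert-capacity handling entirely; in the command, `--moe-eval-capacity-token-fraction -1.0` controls the per-expert capacity used at evaluation time.
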
## Example: GPT Pretraining

### Data Format

@@ -350,6 +350,18 @@ class BertModel(BaseFairseqModel):
                    logger.info("Overwriting " + prefix + "classification_heads." + k)
                    state_dict[prefix + "classification_heads." + k] = v

    def get_normalized_probs_scriptable(
        self,
        net_output,
        log_probs,
        sample=None,
    ):
        # Normalize the model's output logits, returning log-probabilities
        # when log_probs is set and plain probabilities otherwise.
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1)
        else:
            return utils.softmax(logits, dim=-1)

    def forward(
        self,
        src_tokens=None,
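
A hypothetical usage sketch of the method added above; `model` and `tokens` are placeholders, and only the contract visible in the diff (logits live in `net_output[0]`) is assumed:

```python
# Hypothetical example: `model` is a trained BertModel, `tokens` a batch of
# token ids. The method normalizes net_output[0] over the vocabulary dim.
net_output = model(src_tokens=tokens)
log_probs = model.get_normalized_probs_scriptable(net_output, log_probs=True)
probs = model.get_normalized_probs_scriptable(net_output, log_probs=False)
# probs and log_probs.exp() agree up to floating-point error.
```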