shumingma 2023-03-02 02:54:19 -08:00
parent 0cb9695501
commit 20c1e6c611
2 changed files with 65 additions and 1 deletion


@@ -70,7 +70,7 @@ You can quickly get started with our processed vocabulary files: [sentencepiece.
spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
```
### Training Command
### Dense Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
@@ -117,6 +117,58 @@ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH
--deepnorm
```
### Sparse (MoE) Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
--task pretraining \
--tokens-per-sample 512 \
--mask-prob 0.15 \
--span-length 3.0 \
--leave-unmasked-prob 0.0 \
--random-token-prob 0.0 \
--arch mlm_base \
--share-encoder-input-output-embed \
--required-batch-size-multiple 8 \
--spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
--dict-file ${PATH_TO_DATA}/dict.txt \
--optimizer adam \
--adam-betas '(0.9,0.98)' \
--adam-eps 1e-6 \
--clip-norm 2.0 \
--lr-scheduler polynomial_decay \
--lr 0.0005 \
--warmup-updates 10000 \
--total-num-update 125000 \
--max-update 125000 \
--max-sentences 32 \
--update-freq 1 \
--log-format simple \
--log-interval 100 \
--disable-validation \
--save-interval-updates 5000 \
--no-epoch-checkpoints \
--fp16 \
--fp16-init-scale 4 \
--fp16-scale-window 256 \
--min-loss-scale 0.0001 \
--seed 1 \
--save-dir ${PATH_TO_CKPT} \
--ddp-backend=no_c10d \
--distributed-no-spawn \
--reset-dataloader \
--batch-read-ahead 10000 \
--rel-pos-buckets 32 \
--max-rel-pos 128 \
--deepnorm \
--moe-expert-count 64 --moe-freq 2 \
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
--moe-eval-capacity-token-fraction -1.0 \
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
--use-xmoe
```
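The `--moe-*` flags configure the sparse layers: `--moe-freq 2` turns every second FFN layer into a mixture-of-experts layer, `--moe-expert-count 64` gives it 64 experts (matching the 8 nodes × 8 GPUs of the launch command, i.e. one expert per GPU), the gate is computed in fp32 even under `--fp16`, the second expert is sampled in proportion to its gate probability, and the two gate values are renormalized before any capacity-based token dropping; `--use-xmoe` selects the X-MoE gating variant. As a rough sketch of the top-2 routing these flags describe (illustrative PyTorch only, not the torchscale/fairseq implementation; `top2_routing` and its shapes are made up for this example):

```python
import torch
import torch.nn.functional as F


def top2_routing(gate_logits: torch.Tensor):
    """Illustrative top-2 gate matching the flags above (not the real code).

    gate_logits: (num_tokens, num_experts), e.g. num_experts = 64.
    Returns per-token expert indices and renormalized gate weights.
    """
    # --moe-gating-use-fp32: do the gate softmax in fp32 even under --fp16.
    probs = F.softmax(gate_logits.float(), dim=-1)

    # First expert: plain argmax of the gate.
    top1 = probs.argmax(dim=-1, keepdim=True)

    # --moe-second-expert-policy random: sample the second expert in
    # proportion to the remaining probability mass.
    rest = probs.scatter(-1, top1, 0.0)
    top2 = torch.multinomial(rest, num_samples=1)

    # --moe-normalize-gate-prob-before-dropping: renormalize the two gate
    # values *before* any capacity-based token dropping happens.
    g1 = probs.gather(-1, top1)
    g2 = probs.gather(-1, top2)
    denom = g1 + g2
    return (top1.squeeze(-1), top2.squeeze(-1),
            (g1 / denom).squeeze(-1), (g2 / denom).squeeze(-1))
```

The negative `--moe-eval-capacity-token-fraction -1.0` makes evaluation fall back to the training-time expert capacity, and the `moe_cross_entropy` criterion adds the load-balancing gate loss, weighted by `--moe-gate-loss-wt 0.01` and combined across layers by summation (`--moe-gate-loss-combine-method sum`).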
## Example: GPT Pretraining
### Data Format


@@ -349,6 +349,18 @@ class BertModel(BaseFairseqModel):
            if prefix + "classification_heads." + k not in state_dict:
                logger.info("Overwriting " + prefix + "classification_heads." + k)
                state_dict[prefix + "classification_heads." + k] = v

    def get_normalized_probs_scriptable(
        self,
        net_output,
        log_probs,
        sample=None,
    ):
        """Get normalized probabilities (or log probs) from the net's output."""
        logits = net_output[0]
        if log_probs:
            return utils.log_softmax(logits, dim=-1)
        else:
            return utils.softmax(logits, dim=-1)

    def forward(
        self,
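For context on the method added above: in fairseq, `BaseFairseqModel.get_normalized_probs` dispatches to this `_scriptable` variant (TorchScript does not support `super()` calls, hence the helper), so criterions can ask any model for probabilities or log-probabilities uniformly. A minimal usage sketch, assuming `model` is a trained `BertModel` whose `forward` returns a `(logits, extra)` tuple, as the `net_output[0]` indexing suggests (`src_tokens` is a made-up input batch):

```python
import torch

# Hypothetical usage; every name except the method itself is an assumption.
model.eval()
with torch.no_grad():
    net_output = model(src_tokens)  # assumed to return (logits, extra)
    lprobs = model.get_normalized_probs_scriptable(net_output, log_probs=True)
    probs = model.get_normalized_probs_scriptable(net_output, log_probs=False)
    # Sanity check: the softmax output sums to one over the vocabulary dim.
    assert torch.allclose(probs.sum(dim=-1), torch.ones_like(probs[..., 0]))
```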