torchscale released
This commit is contained in:
parent
41f6ee5687
commit
ede048831f
162
.gitignore
vendored
162
.gitignore
vendored
|
@ -348,3 +348,165 @@ MigrationBackup/
|
|||
|
||||
# Ionide (cross platform F# VS Code tools) working folder
|
||||
.ionide/
|
||||
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
173
README.md
173
README.md
|
@ -1,11 +1,176 @@
|
|||
# OpenScale - Transformers at (any) Scale
|
||||
# TorchScale - A Library for Transformers at (Any) Scale
|
||||
|
||||
Fundamental research to improve modeling generality and capability, as well as training stability and efficiency of scaling Transformers.
|
||||
<p>
|
||||
<a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
|
||||
<a href="https://pypi.org/project/torchscale"><img alt="MIT License" src="https://badge.fury.io/py/torchscale.svg" /></a>
|
||||
</p>
|
||||
|
||||
TorchScale is a PyTorch library that allows researchers and developeres to scale up Transformers efficiently and effectively.
|
||||
It has the implemetention of fundamental research to improve modeling generality and capability, as well as training stability and efficiency of scaling Transformers.
|
||||
|
||||
- Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
|
||||
- Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423)
|
||||
- Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
|
||||
|
||||
## News
|
||||
|
||||
- November, 2022: TorchScale 0.1.1 released
|
||||
|
||||
## Installation
|
||||
|
||||
To install:
|
||||
```
|
||||
pip install torchscale
|
||||
```
|
||||
|
||||
Alternatively, you can develop it locally:
|
||||
```
|
||||
git clone https://github.com/microsoft/torchscale.git
|
||||
cd torchscale
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
It takes only several lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
|
||||
|
||||
```python
|
||||
>>> from torchscale.architecture.config import EncoderConfig
|
||||
>>> from torchscale.architecture.encoder import Encoder
|
||||
|
||||
>>> config = EncoderConfig(vocab_size=64000)
|
||||
>>> model = Encoder(config)
|
||||
|
||||
>>> print(model)
|
||||
```
|
||||
|
||||
We also support the `Decoder` architecture and the `EncoderDecoder` architecture:
|
||||
|
||||
```python
|
||||
# Creating a decoder model
|
||||
>>> from torchscale.architecture.config import DecoderConfig
|
||||
>>> from torchscale.architecture.decoder import Decoder
|
||||
|
||||
>>> config = DecoderConfig(vocab_size=64000)
|
||||
>>> decoder = Decoder(config)
|
||||
>>> print(decoder)
|
||||
|
||||
# Creating a encoder-decoder model
|
||||
>>> from torchscale.architecture.config import EncoderDecoderConfig
|
||||
>>> from torchscale.architecture.encoder_decoder import EncoderDecoder
|
||||
|
||||
>>> config = EncoderDecoderConfig(vocab_size=64000)
|
||||
>>> encdec = EncoderDecoder(config)
|
||||
>>> print(encdec)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
We have the examples of how to use TorchScale in the following scenarios/tasks:
|
||||
|
||||
- Language
|
||||
|
||||
* [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining)
|
||||
|
||||
* [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation)
|
||||
|
||||
* [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining)
|
||||
|
||||
- Vision
|
||||
|
||||
* ViT/BEiT [In progress]
|
||||
|
||||
- Speech
|
||||
|
||||
- Multimodal
|
||||
|
||||
* [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
|
||||
|
||||
We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
|
||||
|
||||
## Results
|
||||
|
||||
### Stability Evaluation
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/convergence.png" width="800"/>
|
||||
</p>
|
||||
|
||||
The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
|
||||
|
||||
### Scaling-up Experiments
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/scaling_curve.png" width="800"/>
|
||||
</p>
|
||||
|
||||
TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain.
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository.
|
||||
|
||||
## Citations
|
||||
|
||||
If you find this repository useful, please consider citing our work:
|
||||
|
||||
```
|
||||
@article{deepnet,
|
||||
author = {Hongyu Wang and
|
||||
Shuming Ma and
|
||||
Li Dong and
|
||||
Shaohan Huang and
|
||||
Dongdong Zhang and
|
||||
Furu Wei},
|
||||
title = {{DeepNet}: Scaling Transformers to 1,000 Layers},
|
||||
journal = {CoRR},
|
||||
volume = {abs/2203.00555},
|
||||
year = {2022},
|
||||
}
|
||||
```
|
||||
|
||||
```
|
||||
@article{magneto,
|
||||
author = {Hongyu Wang and
|
||||
Shuming Ma and
|
||||
Shaohan Huang and
|
||||
Li Dong and
|
||||
Wenhui Wang and
|
||||
Zhiliang Peng and
|
||||
Yu Wu and
|
||||
Payal Bajaj and
|
||||
Saksham Singhal and
|
||||
Alon Benhaim and
|
||||
Barun Patra and
|
||||
Zhun Liu and
|
||||
Vishrav Chaudhary and
|
||||
Xia Song and
|
||||
Furu Wei},
|
||||
title = {Foundation Transformers},
|
||||
journal = {CoRR},
|
||||
volume = {abs/2210.06423},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```
|
||||
@article{xmoe,
|
||||
author = {Zewen Chi and
|
||||
Li Dong and
|
||||
Shaohan Huang and
|
||||
Damai Dai and
|
||||
Shuming Ma and
|
||||
Barun Patra and
|
||||
Saksham Singhal and
|
||||
Payal Bajaj and
|
||||
Xia Song and
|
||||
Furu Wei},
|
||||
title = {On the Representation Collapse of Sparse Mixture of Experts},
|
||||
journal = {CoRR},
|
||||
volume = {abs/2204.09179},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
|
@ -19,7 +184,7 @@ provided by the bot. You will only need to do this once across all repos using o
|
|||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Trademarks
|
||||
|
||||
|
@ -27,4 +192,4 @@ This project may contain trademarks or logos for projects, products, or services
|
|||
trademarks or logos is subject to and must follow
|
||||
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
|
||||
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
||||
Any use of third-party trademarks or logos are subject to those third-party's policies.
|
50
SUPPORT.md
50
SUPPORT.md
|
@ -1,25 +1,25 @@
|
|||
# TODO: The maintainer of this repo has not yet edited this file
|
||||
|
||||
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
||||
|
||||
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
||||
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
|
||||
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
|
||||
|
||||
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
||||
|
||||
# Support
|
||||
|
||||
## How to file issues and get help
|
||||
|
||||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
||||
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
||||
feature request as a new Issue.
|
||||
|
||||
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
||||
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
||||
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
||||
|
||||
## Microsoft Support Policy
|
||||
|
||||
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
||||
# TODO: The maintainer of this repo has not yet edited this file
|
||||
|
||||
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
|
||||
|
||||
- **No CSS support:** Fill out this template with information about how to file issues and get help.
|
||||
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
|
||||
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
|
||||
|
||||
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
|
||||
|
||||
# Support
|
||||
|
||||
## How to file issues and get help
|
||||
|
||||
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
|
||||
issues before filing new issues to avoid duplicates. For new issues, file your bug or
|
||||
feature request as a new Issue.
|
||||
|
||||
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
|
||||
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
|
||||
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
|
||||
|
||||
## Microsoft Support Policy
|
||||
|
||||
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
|
||||
|
|
BIN
assets/convergence.png
Normal file
BIN
assets/convergence.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 98 KiB |
BIN
assets/scaling_curve.png
Normal file
BIN
assets/scaling_curve.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 48 KiB |
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
233
examples/fairseq/README.md
Normal file
233
examples/fairseq/README.md
Normal file
|
@ -0,0 +1,233 @@
|
|||
# Example: Integration with FairSeq
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install the repo as a package:
|
||||
git clone https://github.com/msranlp/torchscale.git
|
||||
cd torchscale
|
||||
pip install -e .
|
||||
pip install git+https://github.com/shumingma/fairseq.git@moe
|
||||
pip install git+https://github.com/shumingma/infinibatch.git
|
||||
pip install iopath
|
||||
pip install --upgrade numpy
|
||||
```
|
||||
|
||||
## Example: BERT Pretraining
|
||||
|
||||
### Data Format
|
||||
|
||||
We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on-the-fly from the disk. It requires the data sharded into multiple small files (e.g. 10K lines per file), as well as a JSON file to contain some meta data and the paths to these files.
|
||||
|
||||
The overall data directory should be organized as follows:
|
||||
```
|
||||
Data/
|
||||
├── json/
|
||||
│ ├── train.json
|
||||
│ └── valid.json
|
||||
├── shard/
|
||||
│ ├── train/
|
||||
│ │ ├── 00000.txt
|
||||
│ │ ├── 00001.txt
|
||||
│ │ └── ...
|
||||
│ └── valid/
|
||||
│ ├── 00000.txt
|
||||
│ ├── 00001.txt
|
||||
│ └── ...
|
||||
├── dict.txt
|
||||
└── sentencepiece.bpe.model
|
||||
```
|
||||
|
||||
We recommend that each sharded data files contains no more than 10K lines with one sentence per line, and two documents should be separated with an empty line.
|
||||
```
|
||||
Document 1 Line 1
|
||||
Document 1 Line 2
|
||||
Document 1 Line 3
|
||||
|
||||
Document 2 Line 1
|
||||
Document 2 Line 2
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Also, the JSON file should be in the format like this:
|
||||
```
|
||||
[
|
||||
{
|
||||
"source": [
|
||||
"shard/train/00000.txt",
|
||||
"shard/train/00001.txt",
|
||||
...
|
||||
],
|
||||
"source_lang": "en",
|
||||
"weight": 1.0
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Training Command
|
||||
```bash
|
||||
cd examples/fairseq/
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
|
||||
--task pretraining \
|
||||
--tokens-per-sample 512 \
|
||||
--mask-prob 0.15 \
|
||||
--span-length 3.0 \
|
||||
--leave-unmasked-prob 0.0 \
|
||||
--random-token-prob 0.0 \
|
||||
--criterion masked_lm \
|
||||
--arch mlm_base \
|
||||
--share-encoder-input-output-embed \
|
||||
--required-batch-size-multiple 8 \
|
||||
--spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
|
||||
--dict-file ${PATH_TO_DATA}/dict.txt \
|
||||
--optimizer adam \
|
||||
--adam-betas '(0.9,0.98)' \
|
||||
--adam-eps 1e-6 \
|
||||
--clip-norm 2.0 \
|
||||
--lr-scheduler polynomial_decay \
|
||||
--lr 0.0005 \
|
||||
--warmup-updates 10000 \
|
||||
--total-num-update 125000 \
|
||||
--max-update 125000 \
|
||||
--max-sentences 32 \
|
||||
--update-freq 1 \
|
||||
--log-format simple \
|
||||
--log-interval 100 \
|
||||
--disable-validation \
|
||||
--save-interval-updates 5000 \
|
||||
--no-epoch-checkpoints \
|
||||
--fp16 \
|
||||
--fp16-init-scale 4 \
|
||||
--fp16-scale-window 256 \
|
||||
--min-loss-scale 0.0001 \
|
||||
--seed 1 \
|
||||
--save-dir ${PATH_TO_CKPT} \
|
||||
--ddp-backend=no_c10d \
|
||||
--distributed-no-spawn \
|
||||
--reset-dataloader \
|
||||
--batch-read-ahead 10000 \
|
||||
--rel-pos-buckets 32 \
|
||||
--max-rel-pos 128 \
|
||||
--deepnorm
|
||||
```
|
||||
|
||||
## Example: GPT Pretraining
|
||||
|
||||
### Data Format
|
||||
|
||||
We use the format as in the FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data).
|
||||
|
||||
### Dense Model
|
||||
|
||||
```bash
|
||||
cd examples/fairseq/
|
||||
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
|
||||
${PATH_TO_DATA} \
|
||||
--num-workers 2 \
|
||||
--activation-fn gelu \
|
||||
--share-decoder-input-output-embed \
|
||||
--validate-interval-updates 1000 \
|
||||
--save-interval-updates 1000 \
|
||||
--no-epoch-checkpoints \
|
||||
--memory-efficient-fp16 \
|
||||
--fp16-init-scale 4 \
|
||||
--arch lm_base \
|
||||
--task language_modeling \
|
||||
--sample-break-mode none \
|
||||
--tokens-per-sample 128 \
|
||||
--optimizer adam --adam-betas "(0.9, 0.98)" \
|
||||
--adam-eps 1e-08 \
|
||||
--clip-norm 0.0 \
|
||||
--lr 5e-4 \
|
||||
--lr-scheduler polynomial_decay \
|
||||
--warmup-updates 750 \
|
||||
--dropout 0.1 \
|
||||
--attention-dropout 0.1 \
|
||||
--weight-decay 0.01 \
|
||||
--batch-size 4 \
|
||||
--update-freq 1 \
|
||||
--required-batch-size-multiple 1 \
|
||||
--total-num-update 50000 \
|
||||
--max-update 50000 \
|
||||
--seed 1 \
|
||||
--ddp-backend=c10d
|
||||
```
|
||||
|
||||
### Sparse (MoE) Model
|
||||
|
||||
```bash
|
||||
cd examples/fairseq/
|
||||
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
|
||||
${PATH_TO_DATA} \
|
||||
--num-workers 2 \
|
||||
--activation-fn gelu \
|
||||
--share-decoder-input-output-embed \
|
||||
--validate-interval-updates 1000 \
|
||||
--save-interval-updates 1000 \
|
||||
--no-epoch-checkpoints \
|
||||
--memory-efficient-fp16 \
|
||||
--fp16-init-scale 4 \
|
||||
--arch lm_base \
|
||||
--task language_modeling \
|
||||
--sample-break-mode none \
|
||||
--tokens-per-sample 128 \
|
||||
--optimizer adam --adam-betas "(0.9, 0.98)" \
|
||||
--adam-eps 1e-08 \
|
||||
--clip-norm 0.0 \
|
||||
--lr 5e-4 \
|
||||
--lr-scheduler polynomial_decay \
|
||||
--warmup-updates 750 \
|
||||
--dropout 0.1 \
|
||||
--attention-dropout 0.1 \
|
||||
--weight-decay 0.01 \
|
||||
--batch-size 4 \
|
||||
--update-freq 1 \
|
||||
--required-batch-size-multiple 1 \
|
||||
--total-num-update 50000 \
|
||||
--max-update 50000 \
|
||||
--seed 1 \
|
||||
--ddp-backend=no_c10d \
|
||||
--moe-expert-count 2 --moe-freq 2 \
|
||||
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
|
||||
--moe-eval-capacity-token-fraction -1.0 \
|
||||
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
|
||||
--use-xmoe
|
||||
```
|
||||
|
||||
## Example: Machine Translation
|
||||
|
||||
### Data Format
|
||||
|
||||
We follow the FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data.
|
||||
|
||||
### Dense Model
|
||||
|
||||
```bash
|
||||
cd examples/fairseq/
|
||||
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
|
||||
${PATH_TO_DATA} \
|
||||
--arch mt_base --share-decoder-input-output-embed \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
|
||||
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
|
||||
--dropout 0.3 --weight-decay 0.0001 \
|
||||
--max-tokens 4096 --fp16
|
||||
```
|
||||
|
||||
### Sparse (MoE) Model
|
||||
|
||||
```bash
|
||||
cd examples/fairseq/
|
||||
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
|
||||
${PATH_TO_DATA} \
|
||||
--arch mt_base --share-decoder-input-output-embed \
|
||||
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
|
||||
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
|
||||
--dropout 0.3 --weight-decay 0.0001 \
|
||||
--moe-expert-count 2 --moe-freq 2 \
|
||||
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
|
||||
--moe-eval-capacity-token-fraction -1.0 \
|
||||
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
|
||||
--use-xmoe \
|
||||
--max-tokens 4096 --fp16
|
||||
```
|
0
examples/fairseq/__init__.py
Normal file
0
examples/fairseq/__init__.py
Normal file
7
examples/fairseq/generate.py
Normal file
7
examples/fairseq/generate.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
import models
|
||||
import tasks
|
||||
|
||||
from fairseq_cli.generate import cli_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
7
examples/fairseq/interactive.py
Normal file
7
examples/fairseq/interactive.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
import models
|
||||
import tasks
|
||||
|
||||
from fairseq_cli.interactive import cli_main
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
33
examples/fairseq/models/__init__.py
Normal file
33
examples/fairseq/models/__init__.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
|
||||
MODEL_REGISTRY = {}
|
||||
MODEL_DATACLASS_REGISTRY = {}
|
||||
ARCH_MODEL_REGISTRY = {}
|
||||
ARCH_MODEL_NAME_REGISTRY = {}
|
||||
ARCH_MODEL_INV_REGISTRY = {}
|
||||
ARCH_CONFIG_REGISTRY = {}
|
||||
|
||||
# automatically import any Python files in the models/ directory
|
||||
models_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(models_dir):
|
||||
path = os.path.join(models_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
model_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
module = importlib.import_module("models." + model_name)
|
||||
|
||||
# extra `model_parser` for sphinx
|
||||
if model_name in MODEL_REGISTRY:
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
group_archs = parser.add_argument_group("Named architectures")
|
||||
group_archs.add_argument(
|
||||
"--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
|
||||
)
|
||||
group_args = parser.add_argument_group("Additional command-line arguments")
|
||||
MODEL_REGISTRY[model_name].add_args(group_args)
|
||||
globals()[model_name + "_parser"] = parser
|
459
examples/fairseq/models/bert.py
Normal file
459
examples/fairseq/models/bert.py
Normal file
|
@ -0,0 +1,459 @@
|
|||
import math
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.distributed import fsdp_wrap
|
||||
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
|
||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||
from fairseq.models.transformer import (
|
||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
||||
)
|
||||
from fairseq.modules import PositionalEmbedding
|
||||
from fairseq.models.squad import SQuADHead
|
||||
from torch import Tensor
|
||||
from omegaconf import II
|
||||
from .machine_translation import MTEncoder as Encoder
|
||||
from torchscale.architecture.config import EncoderConfig
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
|
||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class BertConfig(FairseqDataclass):
|
||||
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||
default="relu", metadata={"help": "activation function to use"}
|
||||
)
|
||||
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
||||
attention_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability for attention weights"}
|
||||
)
|
||||
activation_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||
)
|
||||
encoder_embed_dim: int = field(
|
||||
default=512, metadata={"help": "encoder embedding dimension"}
|
||||
)
|
||||
encoder_output_dim: int = field(
|
||||
default=512, metadata={"help": "encoder output dimension"}
|
||||
)
|
||||
encoder_input_dim: int = field(
|
||||
default=512, metadata={"help": "encoder input dimension"}
|
||||
)
|
||||
encoder_ffn_embed_dim: int = field(
|
||||
default=2048, metadata={"help": "encoder embedding dimension for FFN"}
|
||||
)
|
||||
encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"})
|
||||
encoder_attention_heads: int = field(
|
||||
default=8, metadata={"help": "num encoder attention heads"}
|
||||
)
|
||||
encoder_normalize_before: bool = field(
|
||||
default=False, metadata={"help": "apply layernorm before each encoder block"}
|
||||
)
|
||||
no_encoder_final_norm: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "don't add an extra layernorm after the last encoder block"},
|
||||
)
|
||||
no_token_positional_embeddings: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, disables positional embeddings (outside self attention)"
|
||||
},
|
||||
)
|
||||
share_encoder_input_output_embed: bool = field(
|
||||
default=False, metadata={"help": "share encoder input and output embeddings"}
|
||||
)
|
||||
encoder_learned_pos: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "use learned positional embeddings in the encoder"},
|
||||
)
|
||||
layernorm_embedding: bool = field(
|
||||
default=False, metadata={"help": "add layernorm to embedding"}
|
||||
)
|
||||
no_scale_embedding: bool = field(
|
||||
default=False, metadata={"help": "if True, dont scale embeddings"}
|
||||
)
|
||||
checkpoint_activations: bool = field(
|
||||
default=False, metadata={"help": "checkpoint activations at each layer"}
|
||||
)
|
||||
offload_activations: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "move checkpointed activations to CPU after they are used."},
|
||||
)
|
||||
# config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
|
||||
encoder_layerdrop: float = field(
|
||||
default=0.0, metadata={"help": "LayerDrop probability for encoder"}
|
||||
)
|
||||
encoder_layers_to_keep: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "which layers to *keep* when pruning as a comma-separated list"
|
||||
},
|
||||
)
|
||||
# config for Fully Sharded Data Parallel (FSDP) training
|
||||
min_params_to_wrap: int = field(
|
||||
default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||
metadata={
|
||||
"help": (
|
||||
"minimum number of params for a layer to be wrapped with FSDP() when "
|
||||
"training with --ddp-backend=fully_sharded. Smaller values will "
|
||||
"improve memory efficiency, but may make torch.distributed "
|
||||
"communication less efficient due to smaller input sizes. This option "
|
||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||
"--offload-activations are passed."
|
||||
)
|
||||
}
|
||||
)
|
||||
max_source_positions: int = field(
|
||||
default=1024, metadata={"help": "max source positions"}
|
||||
)
|
||||
pooler_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||
default="relu", metadata={"help": "activation function to use for pooler layer"}
|
||||
)
|
||||
pooler_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
|
||||
)
|
||||
# options from other parts of the config
|
||||
# add_bos_token: bool = II("task.add_bos_token")
|
||||
# tokens_per_sample: int = II("task.tokens_per_sample")
|
||||
tpu: bool = II("common.tpu")
|
||||
rel_pos_buckets: int = field(
|
||||
default=0, metadata={"help": ""}
|
||||
)
|
||||
max_rel_pos: int = field(
|
||||
default=0, metadata={"help": ""}
|
||||
)
|
||||
moe_freq: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "Frequency at which we insert MoE Transformer layers"
|
||||
},
|
||||
)
|
||||
moe_expert_count: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "Number of experts in each MoE Layer"
|
||||
}
|
||||
)
|
||||
moe_gating_use_fp32: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use FP32 computations in MoE top2 gating function"
|
||||
}
|
||||
)
|
||||
moe_second_expert_policy: str = field(
|
||||
default='sampling',
|
||||
metadata={
|
||||
"help": "policy for second expert, options: all/sampling/random"
|
||||
}
|
||||
)
|
||||
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
||||
}
|
||||
)
|
||||
moe_expert_ffn_dim: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "MoE expert FFN dimension"
|
||||
}
|
||||
)
|
||||
moe_top1_expert: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use top1 gate instead of top2"
|
||||
}
|
||||
)
|
||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||
default=0.25,
|
||||
metadata={
|
||||
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
||||
}
|
||||
)
|
||||
moe_normalize_expert_grad: Optional[str] = field(
|
||||
default='world_size',
|
||||
metadata={
|
||||
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||
}
|
||||
)
|
||||
record_a2a_perf_stats: Optional[bool] = field(
|
||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||
)
|
||||
dummy_a2a: Optional[bool] = field(
|
||||
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||
)
|
||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||
)
|
||||
ddp_rank: int = II("distributed_training.distributed_rank")
|
||||
deepnorm: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
subln: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
@register_model("mlm", dataclass=BertConfig)
|
||||
class BertModel(BaseFairseqModel):
|
||||
|
||||
def __init__(self, args, encoder):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.encoder = encoder
|
||||
self.padding_idx = self.encoder.embed_tokens.padding_idx
|
||||
self.classification_heads = nn.ModuleDict()
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
|
||||
args.max_source_positions = getattr(
|
||||
args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
|
||||
)
|
||||
|
||||
embed_tokens = cls.build_embedding(
|
||||
args, task.dictionary, args.encoder_embed_dim
|
||||
)
|
||||
|
||||
embed_positions = (
|
||||
PositionalEmbedding(
|
||||
args.max_source_positions,
|
||||
args.encoder_embed_dim,
|
||||
task.dictionary.pad(),
|
||||
learned=args.encoder_learned_pos,
|
||||
)
|
||||
if not args.no_token_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
lm_head = cls.build_lm_head(
|
||||
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
|
||||
)
|
||||
|
||||
config = EncoderConfig()
|
||||
config.override(args)
|
||||
|
||||
encoder = Encoder(
|
||||
config,
|
||||
embed_tokens=embed_tokens,
|
||||
embed_positions=embed_positions,
|
||||
output_projection=lm_head,
|
||||
is_encoder_decoder=False,
|
||||
dictionary=task.dictionary,
|
||||
)
|
||||
|
||||
return cls(args, encoder)
|
||||
|
||||
@classmethod
|
||||
def build_embedding(cls, args, dictionary, embed_dim, path=None):
|
||||
embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
|
||||
return embed_tokens
|
||||
|
||||
@classmethod
|
||||
def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
|
||||
return LMHead(embed_dim, output_dim, activation_fn, weight)
|
||||
|
||||
def output_layer(self, features, masked_tokens=None):
|
||||
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
|
||||
|
||||
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
|
||||
"""Register a classification head."""
|
||||
if name in self.classification_heads:
|
||||
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
||||
prev_inner_dim = self.classification_heads[name].dense.out_features
|
||||
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
||||
logger.warning(
|
||||
're-registering head "{}" with num_classes {} (prev: {}) '
|
||||
'and inner_dim {} (prev: {})'.format(
|
||||
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
||||
)
|
||||
)
|
||||
self.classification_heads[name] = ClassificationHead(
|
||||
self.args.encoder_embed_dim,
|
||||
inner_dim or self.args.encoder_embed_dim,
|
||||
num_classes,
|
||||
self.args.pooler_activation_fn,
|
||||
self.args.pooler_dropout,
|
||||
)
|
||||
|
||||
def register_question_answering_head(self, name, num_classes=None):
|
||||
self.classification_heads[name] = SQuADHead(
|
||||
self.args.encoder_embed_dim,
|
||||
)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
prefix = name + '.' if name != '' else ''
|
||||
|
||||
# upgrade children modules
|
||||
super().upgrade_state_dict_named(state_dict, name)
|
||||
|
||||
# Handle new classification heads present in the state dict.
|
||||
current_head_names = (
|
||||
[] if not hasattr(self, 'classification_heads')
|
||||
else self.classification_heads.keys()
|
||||
)
|
||||
keys_to_delete = []
|
||||
for k in state_dict.keys():
|
||||
if not k.startswith(prefix + 'classification_heads.'):
|
||||
continue
|
||||
|
||||
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
|
||||
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
|
||||
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
|
||||
|
||||
if getattr(self.args, 'load_checkpoint_heads', False):
|
||||
if head_name not in current_head_names:
|
||||
self.register_classification_head(head_name, num_classes, inner_dim)
|
||||
else:
|
||||
if head_name not in current_head_names:
|
||||
logger.warning(
|
||||
'deleting classification head ({}) from checkpoint '
|
||||
'not present in current model: {}'.format(head_name, k)
|
||||
)
|
||||
keys_to_delete.append(k)
|
||||
elif (
|
||||
num_classes != self.classification_heads[head_name].out_proj.out_features
|
||||
or inner_dim != self.classification_heads[head_name].dense.out_features
|
||||
):
|
||||
logger.warning(
|
||||
'deleting classification head ({}) from checkpoint '
|
||||
'with different dimensions than current model: {}'.format(head_name, k)
|
||||
)
|
||||
keys_to_delete.append(k)
|
||||
for k in keys_to_delete:
|
||||
del state_dict[k]
|
||||
|
||||
# Copy any newly-added classification heads into the state dict
|
||||
# with their current weights.
|
||||
if hasattr(self, 'classification_heads'):
|
||||
cur_state = self.classification_heads.state_dict()
|
||||
for k, v in cur_state.items():
|
||||
if prefix + 'classification_heads.' + k not in state_dict:
|
||||
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
|
||||
state_dict[prefix + 'classification_heads.' + k] = v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src_tokens=None,
|
||||
features_only=False,
|
||||
return_all_hiddens=False,
|
||||
classification_head_name=None,
|
||||
masked_tokens=None,
|
||||
**kwargs):
|
||||
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
|
||||
x, extra = encoder_out["encoder_out"], encoder_out
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
if classification_head_name is not None:
|
||||
x = self.classification_heads[classification_head_name](x)
|
||||
elif not features_only:
|
||||
x = self.output_layer(x, masked_tokens=masked_tokens)
|
||||
|
||||
return x, extra
|
||||
|
||||
|
||||
class ClassificationHead(nn.Module):
|
||||
"""Head for sentence-level classification tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
inner_dim,
|
||||
num_classes,
|
||||
activation_fn,
|
||||
pooler_dropout,
|
||||
):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(input_dim, inner_dim)
|
||||
self.activation_fn = utils.get_activation_fn(activation_fn)
|
||||
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||
x = self.dropout(x)
|
||||
x = self.dense(x)
|
||||
x = self.activation_fn(x)
|
||||
x = self.dropout(x)
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
class LMHead(nn.Module):
|
||||
"""Head for masked language modeling."""
|
||||
|
||||
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(embed_dim, embed_dim)
|
||||
self.activation_fn = utils.get_activation_fn(activation_fn)
|
||||
self.layer_norm = LayerNorm(embed_dim)
|
||||
|
||||
if weight is None:
|
||||
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
|
||||
self.weight = weight
|
||||
self.bias = nn.Parameter(torch.zeros(output_dim))
|
||||
|
||||
def forward(self, features, masked_tokens=None, **kwargs):
|
||||
# Only project the masked tokens while training,
|
||||
# saves both memory and computation
|
||||
if masked_tokens is not None:
|
||||
features = features[masked_tokens, :]
|
||||
|
||||
x = self.dense(features)
|
||||
x = self.activation_fn(x)
|
||||
x = self.layer_norm(x)
|
||||
# project back to size of vocabulary with bias
|
||||
x = F.linear(x, self.weight) + self.bias
|
||||
return x
|
||||
|
||||
|
||||
@register_model_architecture("mlm", "mlm_base")
|
||||
def base_unilm_architecture(args):
|
||||
if hasattr(args, "encoder_final_norm"):
|
||||
args.no_encoder_final_norm = not args.encoder_final_norm
|
||||
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
||||
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
||||
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
||||
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
||||
|
||||
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
||||
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
||||
|
||||
# args.add_bos_token = getattr(args, "add_bos_token", False)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.share_encoder_input_output_embed = getattr(
|
||||
args, "share_encoder_input_output_embed", True
|
||||
)
|
||||
args.encoder_output_dim = getattr(
|
||||
args, "encoder_output_dim", args.encoder_embed_dim
|
||||
)
|
||||
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
|
||||
|
||||
# Model training is not stable without this
|
||||
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
|
||||
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
|
||||
|
||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
||||
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
||||
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
||||
args.offload_activations = getattr(args, "offload_activations", False)
|
||||
if args.offload_activations:
|
||||
args.checkpoint_activations = True
|
357
examples/fairseq/models/language_modeling.py
Normal file
357
examples/fairseq/models/language_modeling.py
Normal file
|
@ -0,0 +1,357 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
from fairseq import options, utils
|
||||
from fairseq import distributed_utils
|
||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||
from fairseq.models import (
|
||||
FairseqIncrementalDecoder,
|
||||
FairseqLanguageModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.models.transformer import (
|
||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
|
||||
)
|
||||
from fairseq.modules import PositionalEmbedding
|
||||
from torchscale.architecture.decoder import Decoder
|
||||
from torchscale.architecture.config import DecoderConfig
|
||||
from omegaconf import II
|
||||
|
||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class LanguageConfig(FairseqDataclass):
|
||||
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||
default="relu", metadata={"help": "activation function to use"}
|
||||
)
|
||||
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
||||
attention_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability for attention weights"}
|
||||
)
|
||||
activation_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||
)
|
||||
relu_dropout: float = field(
|
||||
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||
)
|
||||
decoder_embed_dim: int = field(
|
||||
default=512, metadata={"help": "decoder embedding dimension"}
|
||||
)
|
||||
decoder_output_dim: int = field(
|
||||
default=512, metadata={"help": "decoder output dimension"}
|
||||
)
|
||||
decoder_input_dim: int = field(
|
||||
default=512, metadata={"help": "decoder input dimension"}
|
||||
)
|
||||
decoder_ffn_embed_dim: int = field(
|
||||
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
|
||||
)
|
||||
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
|
||||
decoder_attention_heads: int = field(
|
||||
default=8, metadata={"help": "num decoder attention heads"}
|
||||
)
|
||||
decoder_normalize_before: bool = field(
|
||||
default=False, metadata={"help": "apply layernorm before each decoder block"}
|
||||
)
|
||||
no_token_positional_embeddings: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, disables positional embeddings (outside self attention)"
|
||||
},
|
||||
)
|
||||
share_decoder_input_output_embed: bool = field(
|
||||
default=False, metadata={"help": "share decoder input and output embeddings"}
|
||||
)
|
||||
decoder_learned_pos: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "use learned positional embeddings in the decoder"},
|
||||
)
|
||||
layernorm_embedding: bool = field(
|
||||
default=False, metadata={"help": "add layernorm to embedding"}
|
||||
)
|
||||
no_scale_embedding: bool = field(
|
||||
default=False, metadata={"help": "if True, dont scale embeddings"}
|
||||
)
|
||||
checkpoint_activations: bool = field(
|
||||
default=False, metadata={"help": "checkpoint activations at each layer"}
|
||||
)
|
||||
offload_activations: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "move checkpointed activations to CPU after they are used."},
|
||||
)
|
||||
# config for Fully Sharded Data Parallel (FSDP) training
|
||||
min_params_to_wrap: int = field(
|
||||
default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||
metadata={
|
||||
"help": (
|
||||
"minimum number of params for a layer to be wrapped with FSDP() when "
|
||||
"training with --ddp-backend=fully_sharded. Smaller values will "
|
||||
"improve memory efficiency, but may make torch.distributed "
|
||||
"communication less efficient due to smaller input sizes. This option "
|
||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||
"--offload-activations are passed."
|
||||
)
|
||||
}
|
||||
)
|
||||
moe_freq: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "Frequency at which we insert MoE Transformer layers"
|
||||
},
|
||||
)
|
||||
moe_expert_count: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "Number of experts in each MoE Layer"
|
||||
}
|
||||
)
|
||||
moe_gating_use_fp32: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use FP32 computations in MoE top2 gating function"
|
||||
}
|
||||
)
|
||||
moe_second_expert_policy: str = field(
|
||||
default='sampling',
|
||||
metadata={
|
||||
"help": "policy for second expert, options: all/sampling/random"
|
||||
}
|
||||
)
|
||||
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
||||
}
|
||||
)
|
||||
moe_expert_ffn_dim: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "MoE expert FFN dimension"
|
||||
}
|
||||
)
|
||||
moe_top1_expert: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Use top1 gate instead of top2"
|
||||
}
|
||||
)
|
||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||
default=0.25,
|
||||
metadata={
|
||||
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
||||
}
|
||||
)
|
||||
moe_normalize_expert_grad: Optional[str] = field(
|
||||
default='world_size',
|
||||
metadata={
|
||||
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||
}
|
||||
)
|
||||
record_a2a_perf_stats: Optional[bool] = field(
|
||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||
)
|
||||
dummy_a2a: Optional[bool] = field(
|
||||
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||
)
|
||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||
)
|
||||
use_xmoe: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
|
||||
# options from other parts of the config
|
||||
add_bos_token: bool = II("task.add_bos_token")
|
||||
tokens_per_sample: int = II("task.tokens_per_sample")
|
||||
max_target_positions: Optional[int] = II("task.max_target_positions")
|
||||
tpu: bool = II("common.tpu")
|
||||
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
|
||||
fp16: bool = II("common.fp16")
|
||||
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
|
||||
ddp_backend: str = II("distributed_training.ddp_backend")
|
||||
world_size: int = II("distributed_training.distributed_world_size")
|
||||
distributed_rank: int = II("distributed_training.distributed_rank")
|
||||
ddp_rank: int = II("distributed_training.distributed_rank")
|
||||
deepnorm: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
subln: Optional[bool] = field(
|
||||
default=False,
|
||||
)
|
||||
rel_pos_buckets: Optional[int] = field(
|
||||
default=0,
|
||||
)
|
||||
max_rel_pos: Optional[int] = field(
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
@register_model("lm", dataclass=LanguageConfig)
|
||||
class LanguageModel(FairseqLanguageModel):
|
||||
|
||||
def __init__(self, args, decoder):
|
||||
self.args = args
|
||||
super().__init__(decoder)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
|
||||
if getattr(args, "max_target_positions", None) is None:
|
||||
args.max_target_positions = getattr(
|
||||
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
|
||||
)
|
||||
|
||||
embed_tokens = cls.build_embedding(
|
||||
args, task.source_dictionary, args.decoder_embed_dim
|
||||
)
|
||||
|
||||
embed_positions = (
|
||||
PositionalEmbedding(
|
||||
args.max_target_positions,
|
||||
args.decoder_embed_dim,
|
||||
task.dictionary.pad(),
|
||||
learned=args.decoder_learned_pos,
|
||||
)
|
||||
if not args.no_token_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
if args.share_decoder_input_output_embed:
|
||||
output_projection = torch.nn.Linear(
|
||||
embed_tokens.weight.shape[1],
|
||||
embed_tokens.weight.shape[0],
|
||||
bias=False,
|
||||
)
|
||||
output_projection.weight = embed_tokens.weight
|
||||
else:
|
||||
output_projection = torch.nn.Linear(
|
||||
decoder_embed_dim, len(task.dictionary), bias=False
|
||||
)
|
||||
torch.nn.init.normal_(
|
||||
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(args, 'moe_freq', 0) > 0
|
||||
and (
|
||||
getattr(args, 'fp16', False)
|
||||
and not getattr(args, 'memory_efficient_fp16', False)
|
||||
and getattr(args, 'ddp_backend', None) != "fully_sharded"
|
||||
)
|
||||
):
|
||||
assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
|
||||
|
||||
args.ddp_rank = distributed_utils.get_data_parallel_rank()
|
||||
|
||||
config = DecoderConfig()
|
||||
config.override(args)
|
||||
|
||||
decoder = LMDecoder(
|
||||
config,
|
||||
embed_tokens,
|
||||
embed_positions,
|
||||
output_projection,
|
||||
is_encoder_decoder=False,
|
||||
dictionary=task.dictionary,
|
||||
)
|
||||
|
||||
return cls(args, decoder)
|
||||
|
||||
@classmethod
|
||||
def build_embedding(cls, args, dictionary, embed_dim, path=None):
|
||||
return Embedding(len(dictionary), embed_dim, dictionary.pad())
|
||||
|
||||
|
||||
class LMDecoder(Decoder, FairseqIncrementalDecoder):
|
||||
|
||||
def forward(self, src_tokens, **kwargs):
|
||||
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
|
||||
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
|
||||
|
||||
def max_positions(self):
|
||||
return self.embed_positions.max_positions
|
||||
|
||||
def reorder_incremental_state_scripting(
|
||||
self,
|
||||
incremental_state,
|
||||
new_order,
|
||||
):
|
||||
for module in incremental_state:
|
||||
for key in incremental_state[module]:
|
||||
result = incremental_state[module][key].index_select(0, new_order)
|
||||
incremental_state[module][key] = result
|
||||
|
||||
@register_model_architecture("lm", "lm_base")
|
||||
def base_lm_architecture(args):
|
||||
# backward compatibility for older model checkpoints
|
||||
if hasattr(args, "no_tie_adaptive_proj"):
|
||||
# previous models defined --no-tie-adaptive-proj, so use the existence of
|
||||
# that option to determine if this is an "old" model checkpoint
|
||||
args.no_decoder_final_norm = True # old models always set this to True
|
||||
if args.no_tie_adaptive_proj is False:
|
||||
args.tie_adaptive_proj = True
|
||||
if hasattr(args, "decoder_final_norm"):
|
||||
args.no_decoder_final_norm = not args.decoder_final_norm
|
||||
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
||||
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
||||
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
|
||||
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
||||
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
||||
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
||||
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
|
||||
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
||||
args.activation_fn = getattr(args, "activation_fn", "relu")
|
||||
|
||||
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
||||
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
|
||||
|
||||
args.base_layers = getattr(args, "base_layers", 0)
|
||||
args.base_sublayers = getattr(args, "base_sublayers", 1)
|
||||
args.base_shuffle = getattr(args, "base_shuffle", False)
|
||||
|
||||
args.add_bos_token = getattr(args, "add_bos_token", False)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.share_decoder_input_output_embed = getattr(
|
||||
args, "share_decoder_input_output_embed", False
|
||||
)
|
||||
args.character_embeddings = getattr(args, "character_embeddings", False)
|
||||
|
||||
args.decoder_output_dim = getattr(
|
||||
args, "decoder_output_dim", args.decoder_embed_dim
|
||||
)
|
||||
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
||||
|
||||
# Model training is not stable without this
|
||||
args.decoder_normalize_before = True
|
||||
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
|
||||
|
||||
args.adaptive_input = getattr(args, "adaptive_input", False)
|
||||
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
|
||||
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
|
||||
|
||||
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
||||
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
|
||||
|
||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
||||
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
||||
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
||||
args.offload_activations = getattr(args, "offload_activations", False)
|
||||
if args.offload_activations:
|
||||
args.checkpoint_activations = True
|
||||
|
450
examples/fairseq/models/machine_translation.py
Normal file
450
examples/fairseq/models/machine_translation.py
Normal file
|
@ -0,0 +1,450 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq import utils
|
||||
from fairseq.distributed import utils as dist_utils, fsdp_wrap
|
||||
from fairseq import distributed_utils
|
||||
from fairseq import checkpoint_utils
|
||||
from fairseq.models import (
|
||||
FairseqEncoder,
|
||||
FairseqEncoderDecoderModel,
|
||||
FairseqIncrementalDecoder,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.models.transformer import Embedding
|
||||
from fairseq.modules import (
|
||||
AdaptiveSoftmax,
|
||||
FairseqDropout,
|
||||
LayerDropModuleList,
|
||||
LayerNorm,
|
||||
PositionalEmbedding,
|
||||
SinusoidalPositionalEmbedding,
|
||||
)
|
||||
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
|
||||
from torchscale.architecture.encoder import Encoder
|
||||
from torchscale.architecture.config import EncoderConfig, DecoderConfig
|
||||
from .language_modeling import LMDecoder as MTDecoder
|
||||
|
||||
from torch import Tensor
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
|
||||
|
||||
|
||||
@register_model("mt")
|
||||
class TranslationModel(FairseqEncoderDecoderModel):
|
||||
|
||||
def __init__(self, args, encoder, decoder):
|
||||
super().__init__(encoder, decoder)
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
# fmt: off
|
||||
parser.add_argument('--activation-fn',
|
||||
choices=utils.get_available_activation_fns(),
|
||||
help='activation function to use')
|
||||
parser.add_argument('--dropout', type=float, metavar='D',
|
||||
help='dropout probability')
|
||||
parser.add_argument('--attention-dropout', type=float, metavar='D',
|
||||
help='dropout probability for attention weights')
|
||||
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
|
||||
help='dropout probability after activation in FFN.')
|
||||
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
|
||||
help='path to pre-trained encoder embedding')
|
||||
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
|
||||
help='encoder embedding dimension')
|
||||
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
|
||||
help='encoder embedding dimension for FFN')
|
||||
parser.add_argument('--encoder-layers', type=int, metavar='N',
|
||||
help='num encoder layers')
|
||||
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
|
||||
help='num encoder attention heads')
|
||||
parser.add_argument('--encoder-normalize-before', action='store_true',
|
||||
help='apply layernorm before each encoder block')
|
||||
parser.add_argument('--encoder-learned-pos', action='store_true',
|
||||
help='use learned positional embeddings in the encoder')
|
||||
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
|
||||
help='path to pre-trained decoder embedding')
|
||||
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
|
||||
help='decoder embedding dimension')
|
||||
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
|
||||
help='decoder embedding dimension for FFN')
|
||||
parser.add_argument('--decoder-layers', type=int, metavar='N',
|
||||
help='num decoder layers')
|
||||
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
|
||||
help='num decoder attention heads')
|
||||
parser.add_argument('--decoder-learned-pos', action='store_true',
|
||||
help='use learned positional embeddings in the decoder')
|
||||
parser.add_argument('--decoder-normalize-before', action='store_true',
|
||||
help='apply layernorm before each decoder block')
|
||||
parser.add_argument('--decoder-output-dim', type=int, metavar='N',
|
||||
help='decoder output dimension (extra linear layer '
|
||||
'if different from decoder embed dim')
|
||||
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
|
||||
help='share decoder input and output embeddings')
|
||||
parser.add_argument('--share-all-embeddings', action='store_true',
|
||||
help='share encoder, decoder and output embeddings'
|
||||
' (requires shared dictionary and embed dim)')
|
||||
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
|
||||
help='if set, disables positional embeddings (outside self attention)')
|
||||
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
|
||||
help='comma separated list of adaptive softmax cutoff points. '
|
||||
'Must be used with adaptive_loss criterion'),
|
||||
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
|
||||
help='sets adaptive softmax dropout for the tail projections')
|
||||
parser.add_argument('--layernorm-embedding', action='store_true',
|
||||
help='add layernorm to embedding')
|
||||
parser.add_argument('--no-scale-embedding', action='store_true',
|
||||
help='if True, dont scale embeddings')
|
||||
parser.add_argument('--checkpoint-activations', action='store_true',
|
||||
help='checkpoint activations at each layer, which saves GPU '
|
||||
'memory usage at the cost of some additional compute')
|
||||
parser.add_argument('--offload-activations', action='store_true',
|
||||
help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
|
||||
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
|
||||
parser.add_argument('--no-cross-attention', default=False, action='store_true',
|
||||
help='do not perform cross-attention')
|
||||
parser.add_argument('--cross-self-attention', default=False, action='store_true',
|
||||
help='perform cross+self-attention')
|
||||
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
|
||||
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
|
||||
help='LayerDrop probability for encoder')
|
||||
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
|
||||
help='LayerDrop probability for decoder')
|
||||
parser.add_argument('--encoder-layers-to-keep', default=None,
|
||||
help='which layers to *keep* when pruning as a comma-separated list')
|
||||
parser.add_argument('--decoder-layers-to-keep', default=None,
|
||||
help='which layers to *keep* when pruning as a comma-separated list')
|
||||
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
|
||||
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
|
||||
help='iterative PQ quantization noise at training time')
|
||||
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
|
||||
help='block size of quantization noise at training time')
|
||||
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
|
||||
help='scalar quantization noise and scalar quantization at training time')
|
||||
# args for Fully Sharded Data Parallel (FSDP) training
|
||||
parser.add_argument(
|
||||
'--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||
help=(
|
||||
'minimum number of params for a layer to be wrapped with FSDP() when '
|
||||
'training with --ddp-backend=fully_sharded. Smaller values will '
|
||||
'improve memory efficiency, but may make torch.distributed '
|
||||
'communication less efficient due to smaller input sizes. This option '
|
||||
'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
|
||||
'--offload-activations are passed.'
|
||||
)
|
||||
)
|
||||
# args for mixture-of-expert layers
|
||||
parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
|
||||
help='Frequency at which we insert MoE Transformer layers')
|
||||
parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
|
||||
help='Frequency at which we insert MoE Transformer encoder layers')
|
||||
parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
|
||||
help='Frequency at which we insert MoE Transformer decoder layers')
|
||||
parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
|
||||
help='Number of experts in each MoE Layer')
|
||||
parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
|
||||
help="Use FP32 computations in MoE top2 gating function")
|
||||
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
|
||||
help="policy for second expert, options: all/sampling/random")
|
||||
parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
|
||||
help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
|
||||
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
|
||||
help="MoE Expert FFN dimension")
|
||||
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
|
||||
help="Use top1 gate instead of top2")
|
||||
parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
|
||||
help="Fraction of tokens as capacity during validation" + \
|
||||
"if set to negative, use same as training. range: (0.0, 1.0].")
|
||||
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
|
||||
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
|
||||
|
||||
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
|
||||
help="Don't route padding tokens to any expert")
|
||||
parser.add_argument('--use-xmoe', default=False, action='store_true',
|
||||
help="Enable X-Moe")
|
||||
parser.add_argument('--freeze-moe', default=False, action='store_true',
|
||||
help="Freeze MoE Params")
|
||||
parser.add_argument('--deepnorm', default=False, action='store_true',
|
||||
help="Enable DeepNorm")
|
||||
parser.add_argument('--subln', default=False, action='store_true',
|
||||
help="Enable SubLN")
|
||||
parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
|
||||
# args for pseudo-MoE layers
|
||||
parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
|
||||
help="FFN embed dim of alternate pseudo-MoE blocks")
|
||||
parser.add_argument('--rel-pos-buckets', type=int, default=0,
|
||||
help='')
|
||||
parser.add_argument('--max-rel-pos', type=int, default=0,
|
||||
help='')
|
||||
# fmt: on
|
||||
|
||||
@classmethod
def build_model(cls, args, task):
    """Build a new model instance.

    Fills in architecture defaults, builds (optionally shared) token and
    positional embeddings plus the output projection, then constructs the
    encoder/decoder pair and wraps them for FSDP when enabled.

    Args:
        args: fairseq argument namespace for this architecture.
        task: fairseq task providing source/target dictionaries.

    Returns:
        An instance of ``cls`` wrapping the built encoder and decoder.

    Raises:
        ValueError: if --share-all-embeddings is combined with incompatible
            dictionaries, dimensions, or --decoder-embed-path.
    """
    # make sure all arguments are present in older models
    base_architecture(args)

    if getattr(args, "max_source_positions", None) is None:
        args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
    if getattr(args, "max_target_positions", None) is None:
        args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

    args.ddp_rank = distributed_utils.get_data_parallel_rank()

    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    if args.share_all_embeddings:
        if src_dict != tgt_dict:
            raise ValueError("--share-all-embeddings requires a joined dictionary")
        if args.encoder_embed_dim != args.decoder_embed_dim:
            raise ValueError(
                "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
            )
        if args.decoder_embed_path and (
            args.decoder_embed_path != args.encoder_embed_path
        ):
            raise ValueError(
                "--share-all-embeddings not compatible with --decoder-embed-path"
            )
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        # One table serves both sides; tying input/output embeddings follows.
        decoder_embed_tokens = encoder_embed_tokens
        args.share_decoder_input_output_embed = True
    else:
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        decoder_embed_tokens = cls.build_embedding(
            args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
        )
    if getattr(args, "offload_activations", False):
        args.checkpoint_activations = True  # offloading implies checkpointing

    encoder_embed_positions = (
        PositionalEmbedding(
            args.max_source_positions,
            args.encoder_embed_dim,
            src_dict.pad(),
            learned=args.encoder_learned_pos,
        )
        if not args.no_token_positional_embeddings
        else None
    )

    decoder_embed_positions = (
        PositionalEmbedding(
            args.max_target_positions,
            args.decoder_embed_dim,
            tgt_dict.pad(),
            learned=args.decoder_learned_pos,
        )
        if not args.no_token_positional_embeddings
        else None
    )

    if args.share_decoder_input_output_embed:
        output_projection = torch.nn.Linear(
            decoder_embed_tokens.weight.shape[1],
            decoder_embed_tokens.weight.shape[0],
            bias=False,
        )
        output_projection.weight = decoder_embed_tokens.weight
    else:
        # BUGFIX: the original used the bare name `decoder_embed_dim`, which
        # is undefined in this scope and raised NameError whenever input and
        # output embeddings were not shared; read it from args instead.
        output_projection = torch.nn.Linear(
            args.decoder_embed_dim, len(tgt_dict), bias=False
        )
        torch.nn.init.normal_(
            output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
        )

    encoder = cls.build_encoder(
        args,
        encoder_embed_tokens,
        encoder_embed_positions,
        src_dict,
    )
    decoder = cls.build_decoder(
        args,
        decoder_embed_tokens,
        decoder_embed_positions,
        output_projection,
        tgt_dict,
    )

    if not args.share_all_embeddings:
        min_params_to_wrap = getattr(
            args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
        )
        # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
        encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
        decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
    return cls(args, encoder, decoder)
|
||||
|
||||
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
    """Create an ``embed_dim``-wide token embedding table for *dictionary*.

    When *path* is given, pretrained vectors are parsed from that file and
    copied into the freshly created table.
    """
    emb = Embedding(len(dictionary), embed_dim, dictionary.pad())
    if path:
        # if provided, load from preloaded dictionaries
        pretrained = utils.parse_embedding(path)
        utils.load_embedding(pretrained, dictionary, emb)
    return emb
|
||||
|
||||
@classmethod
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
    """Instantiate the torchscale encoder adapted for fairseq."""
    # Copy the matching fields from the fairseq namespace onto the config.
    cfg = EncoderConfig()
    cfg.override(args)

    encoder = MTEncoder(
        cfg,
        embed_tokens,
        embed_positions,
        is_encoder_decoder=True,
        dictionary=dictionary,
    )
    return encoder
|
||||
|
||||
@classmethod
def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
    """Instantiate the torchscale decoder adapted for fairseq."""
    # Copy the matching fields from the fairseq namespace onto the config.
    cfg = DecoderConfig()
    cfg.override(args)

    decoder = MTDecoder(
        cfg,
        embed_tokens,
        embed_positions,
        output_projection,
        is_encoder_decoder=True,
        dictionary=dictionary,
    )
    return decoder
|
||||
|
||||
def forward(
    self,
    src_tokens,
    src_lengths,
    prev_output_tokens,
    return_all_hiddens: bool = False,
    features_only: bool = False,
    **kwargs
):
    """Encode *src_tokens* and decode *prev_output_tokens* against it.

    Note: ``src_lengths`` and any extra ``**kwargs`` are accepted for
    fairseq API compatibility but are not forwarded to the sub-modules.
    """
    enc_out = self.encoder(
        src_tokens,
        return_all_hiddens=return_all_hiddens
    )
    dec_out = self.decoder(
        prev_output_tokens,
        encoder_out=enc_out,
        features_only=features_only,
        return_all_hiddens=return_all_hiddens,
    )
    return dec_out
|
||||
|
||||
def get_normalized_probs(
    self,
    net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
    log_probs: bool,
    sample: Optional[Dict[str, Tensor]] = None,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    # Delegate to the TorchScript-compatible helper inherited from the
    # fairseq base model so scripted and eager paths share one code path.
    return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
||||
|
||||
|
||||
class MTEncoder(Encoder, FairseqEncoder):
    """torchscale ``Encoder`` adapted to the fairseq encoder interface."""

    def forward(self, src_tokens, **kwargs):
        # Positions equal to the dictionary's pad id are masked out of
        # self-attention before delegating to the torchscale encoder.
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output along the batch dim (used by beam search)."""
        # "encoder_out" is time-major (T x B x C) -> select along dim 1;
        # embeddings and the padding mask are batch-major -> dim 0.
        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
        new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
        new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            # NOTE: the per-layer state list is reordered in place.
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
        }

    def max_positions(self):
        # Maximum input length supported by the positional embeddings.
        return self.embed_positions.max_positions
|
||||
|
||||
@register_model_architecture("mt", "mt_base")
|
||||
def base_architecture(args):
|
||||
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
||||
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
||||
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
|
||||
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
||||
args.decoder_ffn_embed_dim = getattr(
|
||||
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
||||
)
|
||||
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
||||
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
||||
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
||||
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||
args.activation_fn = getattr(args, "activation_fn", "relu")
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
||||
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
||||
args.share_decoder_input_output_embed = getattr(
|
||||
args, "share_decoder_input_output_embed", False
|
||||
)
|
||||
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
||||
args.no_token_positional_embeddings = getattr(
|
||||
args, "no_token_positional_embeddings", False
|
||||
)
|
||||
args.adaptive_input = getattr(args, "adaptive_input", False)
|
||||
args.no_cross_attention = getattr(args, "no_cross_attention", False)
|
||||
args.cross_self_attention = getattr(args, "cross_self_attention", False)
|
||||
|
||||
args.decoder_output_dim = getattr(
|
||||
args, "decoder_output_dim", args.decoder_embed_dim
|
||||
)
|
||||
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
||||
|
||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
||||
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
||||
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
||||
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
||||
args.offload_activations = getattr(args, "offload_activations", False)
|
||||
if args.offload_activations:
|
||||
args.checkpoint_activations = True
|
||||
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
||||
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
|
||||
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
||||
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
||||
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
||||
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
||||
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
|
||||
args.is_moe = getattr(args, "is_moe", False)
|
||||
args.selected_expert_count = getattr(args, "selected_expert_count", 2)
|
32
examples/fairseq/tasks/__init__.py
Normal file
32
examples/fairseq/tasks/__init__.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
import argparse
import importlib
import os

# register dataclass
# Registries are declared here and presumably populated by the task modules
# imported below as they register themselves — confirm against the
# register_task decorator used by those modules.
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()

# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
    path = os.path.join(tasks_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        # Strip the ".py" suffix for modules; sub-packages keep their name.
        task_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("tasks." + task_name)

        # expose `task_parser` for sphinx
        if task_name in TASK_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_task = parser.add_argument_group("Task name")
            # fmt: off
            group_task.add_argument('--task', metavar=task_name,
                                    help='Enable this task with: ``--task=' + task_name + '``')
            # fmt: on
            group_args = parser.add_argument_group("Additional command-line arguments")
            TASK_REGISTRY[task_name].add_args(group_args)
            globals()[task_name + "_parser"] = parser
|
0
examples/fairseq/tasks/data/__init__.py
Normal file
0
examples/fairseq/tasks/data/__init__.py
Normal file
78
examples/fairseq/tasks/data/basic_loader.py
Normal file
78
examples/fairseq/tasks/data/basic_loader.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
import math
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
from infinibatch.iterators import CheckpointableIterator
|
||||
from . import utils
|
||||
|
||||
class BaseBatchGen(CheckpointableIterator):
    """
    This is a base class for batch generators that use infinibatch.

    Subclasses must implement :meth:`_build_iter` and assign the resulting
    infinibatch iterator to ``self._iter``; all iteration, checkpointing
    (``getstate``/``setstate``) and the fairseq epoch-iterator protocol are
    then delegated to it.
    """

    def __init__(self):
        self._iter = None  # assigned by _build_iter() in subclasses
        self.epoch = 1
        self.next_epoch_idx = 1
        self.sharded_checkpoint = True
        self.should_close_after_finished = True

    def _build_iter(self):
        """
        Build infinibatch iterator and assign to self._iter
        """
        raise NotImplementedError()

    def _move_to_tensor(self, batch):
        # Recursively convert every numpy array inside the (possibly nested)
        # batch structure into a torch tensor.
        def to_tensor(x):
            return torch.tensor(x)

        return utils.apply_to_sample(to_tensor, batch)

    @property
    def iterator(self):
        if self._iter is None:
            # BUGFIX: message previously read "must called first".
            raise NotImplementedError("_build_iter() must be called first")
        return self._iter

    def __iter__(self):
        if self._iter is None:
            # BUGFIX: message previously read "must called first".
            raise NotImplementedError("_build_iter() must be called first")
        return self._iter

    def __next__(self):
        return next(self._iter)

    def setstate(self, value):
        self._iter.setstate(value)

    def getstate(self):
        return self._iter.getstate()

    def close(self):
        self._iter.close()

    def __len__(self) -> int:
        # Effectively infinite: fairseq requires a finite __len__, so report
        # a number far larger than any realistic epoch.
        return 819200000

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        # The stream is infinite, so the same iterator serves every epoch.
        return self

    def end_of_epoch(self) -> bool:
        return False

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        return self.getstate()

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.setstate(state_dict)

    @property
    def first_batch(self):
        return "DUMMY"
|
308
examples/fairseq/tasks/data/mlm_loader.py
Normal file
308
examples/fairseq/tasks/data/mlm_loader.py
Normal file
|
@ -0,0 +1,308 @@
|
|||
import glob
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import itertools
|
||||
import copy
|
||||
|
||||
from infinibatch import iterators
|
||||
from .basic_loader import BaseBatchGen
|
||||
from .utils import NativeCheckpointableIterator, WeightIterator
|
||||
|
||||
|
||||
class MLMLoader(BaseBatchGen):
    """Infinibatch-based loader producing masked-LM (and span-corruption)
    batches from sharded text files described by a JSON manifest."""

    def __init__(
        self,
        args,
        dataset,
        dictionary,
        tokenizer,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
    ):
        super().__init__()
        self.args = args
        # `dataset` is a namespace carrying the manifest (data), its root
        # directory, and the shuffle flag — see PLMTask.load_dataset.
        self.data = dataset.data
        self.data_dir = dataset.data_dir
        self.shuffle = dataset.shuffle
        self.dictionary = dictionary
        self.tokenizer = tokenizer

        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.max_positions = max_positions
        self.tokens_per_sample = args.tokens_per_sample
        self.sample_break_mode = args.sample_break_mode
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.required_batch_size_multiple = required_batch_size_multiple
        # Stringified because downstream infinibatch iterators receive it as
        # their seed — presumably they accept str seeds; confirm.
        self.seed = str(seed)
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.batch_read_ahead = args.batch_read_ahead

        self._build_iter()

    def _build_iter(self):
        """Assemble the full pipeline: tokenize -> batchify -> prefetch."""
        tokenized_lines = self._multilingual_tokenize()
        self.padded_batches = self._batchify(tokenized_lines)

        prefetch_batches = iterators.PrefetchIterator(
            self.padded_batches,
            buffer_size=10000,
            buffer_in_main_process=True,
            log_empty_buffer_warning=True and self.shard_id == 0,
        )

        # Convert numpy arrays to torch tensors as batches are consumed.
        prefetch_batches = iterators.MapIterator(
            prefetch_batches, self._move_to_tensor
        )

        self._iter = prefetch_batches

    def _multilingual_tokenize(self):
        """Build one tokenized stream per corpus entry and multiplex them,
        sampling corpora by explicit 'weight' or, failing that, 'count'."""
        multilingual_iters = []
        weights = []

        for data in self.data:
            multilingual_iters.append(
                self._tokenize(data)
            )
            if 'weight' in data:
                weights.append(float(data['weight']))
            else:
                weights.append(int(data['count']))

        # Single corpus: no multiplexing needed.
        if len(multilingual_iters) == 1:
            return multilingual_iters[0]

        # NOTE(review): WeightIterator in .utils is defined as
        # __init__(self, weights, seed) — this call passes no seed and would
        # raise TypeError when more than one corpus is configured; confirm.
        sampling_iterator = WeightIterator(weights)
        control_iterator = NativeCheckpointableIterator(sampling_iterator)
        tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)

        return tokenized_lines

    def _tokenize(self, data):
        '''
        data:
        {
            'source': list[Path],
            'source_lang': str,
            'count': int,
            'weight': float,
            'name': str,
        }
        '''
        # Pair each shard file with its language tag.
        dataset = list(
            zip(
                data['source'],
                itertools.repeat(data['source_lang']),
            )
        )

        if self.shuffle:
            # Training: endlessly re-permute the shard list, split by rank.
            chunk_files = \
                iterators.InfinitePermutationSourceIterator(
                    dataset,
                    seed=self.seed,
                    shuffle=self.shuffle,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )
        else:
            # Evaluation: a single deterministic pass, split by rank.
            chunk_files = \
                iterators.ChunkedSourceIterator(
                    dataset,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )

        # Expand each shard file into token-id documents, then apply the
        # masking/corruption transform with a checkpointable RNG.
        tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
        tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)

        return tokenized_lines


    def _batchify(self, lines):
        """Group prepared documents into padded numpy batches."""

        if self.max_sentences is not None:
            # Fixed batch size; optionally shuffle within a read-ahead block.
            if self.batch_read_ahead > 0:
                lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
            batches = iterators.FixedBatchIterator(lines, self.max_sentences)
        else:
            # Token-budget batching: size each batch so that
            # batch_size * max_len <= max_tokens, rounded down to the
            # required multiple (but at least 1).
            def dynamic_batch_size(sample):
                lengths = [len(x) for x in sample]
                batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
                return max(1, batch_size)

            batches = iterators.BucketedReadaheadBatchIterator(
                    lines,
                    read_ahead=self.batch_read_ahead,
                    key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                    batch_size=dynamic_batch_size,
                    shuffle=self.shuffle,
                    seed=self.seed,
            )

        def collate(batch):
            # Each element is (mlm_input_ids, mlm_label_ids,
            #                  s2s_input_ids, s2s_label_ids).
            batch_size = len(batch)

            mlm_source_max_length = max([len(x[0]) for x in batch])
            mlm_target_max_length = max([len(x[1]) for x in batch])
            s2s_source_max_length = max([len(x[2]) for x in batch])
            s2s_target_max_length = max([len(x[3]) for x in batch])

            # Pad everything to the per-field max length with the pad id.
            mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                         fill_value=self.dictionary.pad())

            for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
                mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
                mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
                s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
                # Teacher forcing: target is labels shifted left, the
                # previous-input stream is labels shifted right.
                s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
                s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]

            # Only the MLM fields are emitted; the s2s arrays are built but
            # not included in the returned batch.
            ret_batch = {
                'net_input': {
                    'src_tokens': mlm_source_ids.astype(np.int64),
                },
                'target': mlm_target_ids.astype(np.int64),
                'nsentences': batch_size,
                'ntokens': sum([len(x[0]) for x in batch]),
            }

            return ret_batch

        padded_batches = iterators.MapIterator(
            batches, collate
        )

        return padded_batches

    def _prepare(self, _random, doc):
        """Produce both corruption views of one document."""
        nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
        nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
        return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans

    def _mask_lm(self, _random, doc):
        """BERT-style masking: replace ~mask_prob of positions with <mask>;
        labels hold the original id at masked positions, pad elsewhere."""
        def mask_tokens():
            return f"<mask>"

        length = len(doc)
        mask_tokens_num = int(length * self.args.mask_prob)
        # At least one mask, never the whole document.
        mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
        possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
        possible_mask_positions = sorted(possible_mask_positions)

        nonmasked_tokens = copy.deepcopy(doc)
        masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]

        for position in possible_mask_positions:
            # masked_tokens.append(nonmasked_tokens[position])
            masked_tokens[position] = nonmasked_tokens[position]
            nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]

        return nonmasked_tokens, masked_tokens

    def _span_corruption(self, _random, doc):
        """T5-style span corruption: cut ~mask_prob of tokens into spans of
        average length span_length, replacing each span with a sentinel in
        the source and prefixing it with the same sentinel in the target."""

        def mask_tokens(i):
            return f"<mask_{i}>"

        length = len(doc)
        noise_tokens_num = int(length * self.args.mask_prob)
        noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
        noise_spans_num = int(noise_tokens_num / self.args.span_length)
        noise_spans_num = max(noise_spans_num, 1)
        nonnoise_tokens_num = length - noise_tokens_num

        # Randomly split the noise-token budget into span lengths.
        if noise_spans_num == 1:
            noise_split_positions = [0, noise_tokens_num]
        else:
            possible_split_positions = list(range(1, noise_tokens_num))
            _random.shuffle(possible_split_positions)
            noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
            noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

        # Random insertion points of the spans within the kept text.
        possible_insert_positions = list(range(nonnoise_tokens_num))
        _random.shuffle(possible_insert_positions)
        noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

        nonnoise_spans, noise_spans = [], []
        last_end = 0
        for i in range(noise_spans_num):
            start_pos = noise_insert_positions[i] + noise_split_positions[i]
            end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
            mask_id = self.dictionary.indices[mask_tokens(i)]

            if getattr(self.args, "remove_target_sentinel", False):
                noise_spans.append(doc[start_pos:end_pos])
            else:
                noise_spans.append([mask_id] + doc[start_pos:end_pos])

            if getattr(self.args, "remove_source_sentinel", False):
                nonnoise_spans.extend(doc[last_end:start_pos])
            else:
                nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])

            last_end = end_pos

        nonnoise_spans.extend(doc[last_end:])
        noise_spans = sum(noise_spans, [])

        return nonnoise_spans, noise_spans

    def _read_from_files(self, source_file, source_lang):
        """Yield token-id documents (each <= tokens_per_sample, starting with
        BOS) from one shard file; documents are delimited per
        sample_break_mode."""
        # data = []
        file_path = os.path.join(self.data_dir, source_file)

        if not os.path.exists(file_path):
            print('| file {} not exists'.format(file_path), flush=True)
            return iter([]) # skip bad file

        with open(file_path, 'r', encoding='utf8') as f:
            lines = f.read().strip().split('\n')

        doc = [self.dictionary.bos()]
        for line in lines:
            if line == "":
                # Blank line = document boundary in 'complete_doc' mode.
                if self.sample_break_mode == 'complete_doc':
                    # data.append(doc)
                    yield doc
                    doc = [self.dictionary.bos()]
                continue

            tokenized_line = self.tokenizer.EncodeAsPieces(line)
            tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]

            # Drop single lines that alone exceed the budget.
            if len(tokenized_id) > self.tokens_per_sample:
                continue
            # Flush the current document when the next line would overflow.
            if len(doc) + len(tokenized_id) > self.tokens_per_sample:
                # data.append(doc)
                yield doc
                doc = [self.dictionary.bos()]
            doc.extend(tokenized_id)

        # Emit the trailing partial document if non-trivial.
        if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
            # data.append(doc)
            yield doc

        # return data
|
82
examples/fairseq/tasks/data/utils.py
Normal file
82
examples/fairseq/tasks/data/utils.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import gzip
|
||||
import numpy as np
|
||||
from random import Random
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
|
||||
import collections
|
||||
from infinibatch import iterators
|
||||
|
||||
def apply_to_sample(f, sample):
    """Apply *f* to every numpy array found anywhere inside *sample*.

    Dicts (including OrderedDicts), lists, tuples and sets are rebuilt with
    the same structure; any other leaf value is returned unchanged. A sized
    container that is empty yields an empty dict.
    """
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _walk(node):
        if isinstance(node, np.ndarray):
            return f(node)
        # OrderedDict must be checked before plain dict (it is a subclass);
        # its instance __dict__ carries attributes that need to be preserved.
        if isinstance(node, collections.OrderedDict):
            rebuilt = collections.OrderedDict(
                (key, _walk(value)) for key, value in node.items()
            )
            rebuilt.__dict__ = node.__dict__
            return rebuilt
        if isinstance(node, dict):
            return {key: _walk(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_walk(item) for item in node]
        if isinstance(node, tuple):
            return tuple(_walk(item) for item in node)
        if isinstance(node, set):
            return {_walk(item) for item in node}
        return node

    return _walk(sample)
|
||||
|
||||
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    """Wrap a plain iterable so it can be checkpointed by item count."""

    def __init__(self, iterable: Iterable):
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        # The full state is just how many items have been consumed.
        return {'num_items_yielded': self._num_items_yielded}

    def setstate(self, checkpoint: Optional[Dict]):
        # Restart iteration from the beginning and fast-forward to the
        # checkpointed position.
        # NOTE(review): assumes _input_iterable can be re-iterated from the
        # start (i.e. is not a one-shot iterator) — confirm at call sites.
        self._iterator = iter(self._input_iterable)
        self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0

    def __next__(self):
        item = next(self._iterator)
        self._num_items_yielded += 1
        return item

    def close(self):
        pass
|
||||
|
||||
|
||||
class WeightIterator(object):
    """Endless iterator of indices sampled in proportion to *weights*.

    Each ``__next__`` draws one index from ``range(len(weights))`` with
    probability proportional to its weight. The RNG state is captured after
    every draw and exposed through getstate/setstate, so the stream can be
    checkpointed and resumed deterministically.
    """

    def __init__(self, weights, seed):
        self.weights = weights
        self.seed = seed
        self.control_index = list(range(len(weights)))
        self.setstate(None)

    def __iter__(self):
        return self

    def getstate(self):
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        self._random_state = checkpoint["random_state"] if checkpoint else None
        # The RNG itself is rebuilt lazily on the next draw.
        self._random = None

    def __next__(self):
        if self._random is None:
            # Lazy initialization: seed a fresh RNG and, when resuming from
            # a checkpoint, restore its exact state.
            self._random = Random(self.seed)
            if self._random_state is not None:
                self._random.setstate(self._random_state)
        drawn = self._random.choices(self.control_index, self.weights)[0]
        # Snapshot the RNG state so getstate() reflects this draw.
        self._random_state = self._random.getstate()
        return drawn

    def close(self):
        pass
|
207
examples/fairseq/tasks/pretraining.py
Normal file
207
examples/fairseq/tasks/pretraining.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
import json
|
||||
from omegaconf import MISSING, II, OmegaConf
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from fairseq import utils
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.tasks import FairseqTask, register_task
|
||||
from .data.mlm_loader import MLMLoader
|
||||
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
||||
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
|
||||
import sentencepiece as spm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
|
||||
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
|
||||
|
||||
@dataclass
class PretrainingConfig(FairseqDataclass):
    """Configuration for the masked-LM / span-corruption pretraining task.

    Consumed by PLMTask and forwarded (as ``args``) to MLMLoader, which
    reads tokens_per_sample, sample_break_mode, mask_prob, span_length,
    remove_source_sentinel, remove_target_sentinel and batch_read_ahead.
    """
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="complete",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    mask_prob: float = field(
        default=0.15,
        metadata={"help": "probability of replacing a token with mask"},
    )
    leave_unmasked_prob: float = field(
        default=0.1,
        metadata={"help": "probability that a masked token is unmasked"},
    )
    random_token_prob: float = field(
        default=0.1,
        metadata={"help": "probability of replacing a token with a random token"},
    )
    freq_weighted_replacement: bool = field(
        default=False,
        metadata={"help": "sample random replacement words based on word frequencies"},
    )
    mask_whole_words: bool = field(
        default=False,
        metadata={"help": "mask whole words; you may also want to set --bpe"},
    )
    mask_multiple_length: int = field(
        default=1,
        metadata={"help": "repeat the mask indices multiple times"},
    )
    mask_stdev: float = field(
        default=0.0,
        metadata={"help": "stdev of the mask length"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    # Inherited from the top-level fairseq config via omegaconf interpolation.
    seed: int = II("common.seed")
    span_length: float = field(
        default=3.0,
        metadata={"help": "average span length for masking"},
    )
    remove_source_sentinel: bool = field(
        default=False,
        metadata={"help": "remove the source sentinel for the span corruption task"},
    )
    remove_target_sentinel: bool = field(
        default=False,
        metadata={"help": "remove the target sentinel for the span corruption task"},
    )
    batch_read_ahead: int = field(
        default=100000,
        metadata={"help": "batch read ahead size for infinibatch"},
    )
    required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
    spm_model: str = field(
        default="",
        metadata={
            "help": "sentencepice model to tokenize the data"
        },
    )
    dict_file: str = field(
        default="",
        metadata={
            "help": ""
        },
    )
|
||||
|
||||
|
||||
@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):
    """Masked-LM pretraining task.

    Loads a fairseq ``Dictionary`` plus a SentencePiece tokenizer and serves
    batches through :class:`MLMLoader` (infinibatch-backed), so the standard
    fairseq dataset machinery is bypassed in :meth:`get_batch_iterator`.
    """

    def __init__(self, cfg, dictionary, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.dictionary = dictionary
        self.tokenizer = tokenizer
        self.seed = cfg.seed
        # Index of the generic mask token; sentinel tokens <mask_0>..<mask_99>
        # are registered in setup_task.
        self.mask_idx = dictionary.index("<mask>")

    @classmethod
    def setup_task(cls, cfg, **kwargs):
        """Build the task: load the dictionary, register mask symbols, load SPM."""
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        if cfg.dict_file != "":
            dictionary = Dictionary.load(cfg.dict_file)
        else:
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))

        # add mask token and 100 span-corruption sentinel tokens
        dictionary.add_symbol("<mask>")
        for i in range(100):
            dictionary.add_symbol(f"<mask_{i}>")

        dictionary.pad_to_multiple_(cfg.required_batch_size_multiple)
        logger.info("dictionary: {} types".format(len(dictionary)))

        tokenizer = spm.SentencePieceProcessor()
        tokenizer.Load(cfg.spm_model)
        return cls(cfg, dictionary, tokenizer)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Record split metadata (shard list JSON, data dir, shuffle flag).

        Fix: the JSON file is now closed deterministically — the original
        passed a bare ``open(...)`` into ``json.load`` and leaked the handle.
        """
        with open(f'{self.cfg.data}/json/{split}.json') as f:
            split_data = json.load(f)
        self.datasets[split] = Namespace(
            data=split_data,
            data_dir=self.cfg.data,
            shuffle=split == 'train',  # only shuffle the training split
        )

    def dataset(self, split):
        """Return the metadata namespace for *split*.

        Raises:
            KeyError: if :meth:`load_dataset` was never called for *split*.
        """
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)

        return self.datasets[split]

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        # NOTE: num_workers / epoch / data_buffer_size / disable_iterator_cache
        # are accepted for interface compatibility but not forwarded —
        # MLMLoader manages its own prefetching via infinibatch.
        return MLMLoader(
            self.cfg,
            dataset,
            self.dictionary,
            self.tokenizer,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            max_positions=max_positions,
            ignore_invalid_inputs=ignore_invalid_inputs,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        # Shared vocabulary: source and target use the same dictionary.
        return self.dictionary
|
8
examples/fairseq/train.py
Normal file
8
examples/fairseq/train.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
import models
import tasks

from fairseq_cli.train import cli_main


if __name__ == "__main__":
    # Importing `models` and `tasks` above registers the custom architectures
    # and tasks with fairseq's registries before handing control to the stock
    # fairseq training CLI.
    cli_main()
|
0
examples/fairseq/utils/__init__.py
Normal file
0
examples/fairseq/utils/__init__.py
Normal file
75
examples/fairseq/utils/sparse_clip.py
Normal file
75
examples/fairseq/utils/sparse_clip.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import torch
|
||||
import warnings
|
||||
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
|
||||
import torch.distributed as dist
|
||||
import math
|
||||
|
||||
@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
    """Clip gradient norm in place, with MoE-expert / sharded parameter support.

    Dense grads are normed locally; expert and sharded grads are normed
    separately and all-reduced across workers (sum of squares) before being
    combined into the total norm.

    Args:
        params: a Tensor or iterable of Tensors whose ``.grad`` is clipped.
        max_norm: clip threshold; ``<= 0`` disables clipping (norm still returned).
        moe_expert_count: number of MoE experts, used to rescale expert grads.
        aggregate_norm_fn: optional callable applied to the final total norm.

    Returns:
        The total gradient norm (after ``aggregate_norm_fn``, if provided).
    """
    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(filter(grad_exists, params))
    grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []

    # Fix: torch.distributed exposes get_world_size(), not get_global_world_size(),
    # and calling it without an initialized default process group raises — guard
    # both so single-process (non-distributed) runs work.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    denom = math.sqrt(max(world_size, moe_expert_count))
    for p in params:
        if hasattr(p, "expert"):
            # NOTE(review): division creates a new tensor, so the in-place
            # clip below does not write back into p.grad for expert params —
            # preserved from the original; confirm whether that is intended.
            expert_grads.append(p.grad.detach() / denom)
        elif hasattr(p, "base_expert"):
            base_expert_grads.append(p.grad.detach())
        elif hasattr(p, "_is_sharded"):
            sharded_grads.append(p.grad.detach())
        else:
            grads.append(p.grad.detach())

    if len(grads) == 0:
        if len(params) > 0:
            total_norm = params[0].new_tensor(0.0)
        else:
            total_norm = torch.tensor(0.0)
    elif len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        if multi_tensor_l2norm_available:
            total_norm = multi_tensor_total_norm(grads)
        else:
            if torch.cuda.is_available():
                warnings.warn(
                    "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                    "you may get better performance by installing NVIDIA's apex library"
                )
                device = torch.cuda.current_device()
            elif grads[0].device.type == "xla":
                device = grads[0].device
            else:
                device = torch.device("cpu")
            total_norm = torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
                )
            )

    # calculate split_norm and all_reduce with other workers
    norms = [total_norm]
    for split_grads in [expert_grads, sharded_grads]:
        if len(split_grads) == 0:
            continue
        split_norm = torch.norm(
            torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
        )
        if dist.is_initialized():
            # all-reduce the sum of squares, then take the root
            split_norm.pow_(2)
            dist.all_reduce(split_norm)
            split_norm.sqrt_()
        norms.append(split_norm)
    if len(norms) > 1:
        total_norm = torch.norm(torch.stack(norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        # epsilon avoids division by zero when the total norm is ~0
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads + sharded_grads + base_expert_grads:
            g.mul_(clip_coef)
    return total_norm
|
25
setup.py
Normal file
25
setup.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from io import open
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
# Read the long description up front with a context manager — the original
# passed a bare open(...).read() into setup(), leaking the file handle.
with open("README.md", "r", encoding="utf-8") as readme:
    long_description = readme.read()

setup(
    name="torchscale",
    version="0.1.1",
    author="TorchScale Team",
    author_email="Shuming.Ma@microsoft.com",
    description="Transformers at any scale",
    long_description=long_description,
    long_description_content_type="text/markdown",
    keywords="Transformers at any scale",
    license="MIT",
    url="https://github.com/msranlp/torchscale",
    packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
    install_requires=[
        "apex",
        "torch>=1.8",
        "fairscale==0.4.0",
        "timm==0.4.12",
    ],
    python_requires=">=3.8.0",
    classifiers=[
        "Programming Language :: Python :: 3",
    ],
)
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
33
tests/test_decoder.py
Normal file
33
tests/test_decoder.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import pytest
|
||||
from torchscale.architecture.config import DecoderConfig
|
||||
from torchscale.architecture.decoder import Decoder
|
||||
import torch
|
||||
|
||||
# Config overrides exercised against the default DecoderConfig.
testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"decoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "decoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]


@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
    """Smoke test: a Decoder built from each config variant runs forward."""
    config = DecoderConfig(**args)
    decoder = Decoder(config)
    tokens = torch.ones(2, 10)
    embeddings = torch.rand(2, 10, config.decoder_embed_dim)
    decoder(
        prev_output_tokens=tokens,
        token_embeddings=embeddings,
        features_only=True,
    )
|
28
tests/test_encoder.py
Normal file
28
tests/test_encoder.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import pytest
|
||||
from torchscale.architecture.config import EncoderConfig
|
||||
from torchscale.architecture.encoder import Encoder
|
||||
import torch
|
||||
|
||||
# Config overrides exercised against the default EncoderConfig.
testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "encoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_encoder_input_output_embed": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]


@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
    """Smoke test: an Encoder built from each config variant runs forward."""
    config = EncoderConfig(**args)
    encoder = Encoder(config)
    embeddings = torch.rand(2, 10, config.encoder_embed_dim)
    encoder(src_tokens=None, token_embeddings=embeddings)
|
43
tests/test_encoder_decoder.py
Normal file
43
tests/test_encoder_decoder.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from torchscale.architecture.config import EncoderDecoderConfig
|
||||
from torchscale.architecture.encoder_decoder import EncoderDecoder
|
||||
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
|
||||
import torch
|
||||
|
||||
# Config overrides exercised against the default EncoderDecoderConfig.
testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False, "decoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"share_all_embeddings": True},
    {"checkpoint_activations": True},
    {"fsdp": True},
]


@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
    """Smoke test: build a full EncoderDecoder per variant and run it.

    NOTE(review): the name ``test_decoder`` is a copy-paste from
    test_decoder.py although this exercises the encoder-decoder; kept
    unchanged so collected test ids stay stable.
    """
    config = EncoderDecoderConfig(**args)
    model = EncoderDecoder(
        config,
        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
        encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
        decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
    )

    source = torch.ones(2, 20).long()
    previous = torch.ones(2, 10).long()

    model(
        src_tokens=source,
        prev_output_tokens=previous,
        features_only=True,
    )
|
0
torchscale/__init__.py
Normal file
0
torchscale/__init__.py
Normal file
0
torchscale/architecture/__init__.py
Normal file
0
torchscale/architecture/__init__.py
Normal file
160
torchscale/architecture/config.py
Normal file
160
torchscale/architecture/config.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
|
||||
|
||||
class EncoderConfig(object):
    """Hyperparameters for a torchscale Encoder.

    All values are keyword overrides of the defaults below; unknown keys are
    silently ignored.
    """

    def __init__(self, **kwargs):
        # Model shape
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        # Regularization / activations
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        # Mixture-of-experts
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        # Relative position bias
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        # Normalization / init schemes
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Vision
        self.img_size = kwargs.pop("img_size", 224)
        self.patch_size = kwargs.pop("patch_size", 16)
        self.in_chans = kwargs.pop("in_chans", 3)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        # deepnorm and subln are mutually exclusive. Fix: deepnorm now also
        # forces subln off — previously the default subln=True immediately
        # re-enabled pre-norm below, silently defeating deepnorm unless the
        # caller remembered to pass subln=False as well.
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True

    def override(self, args):
        """Overwrite hyperparameters with any non-None attributes of *args*."""
        for hp in self.__dict__.keys():
            value = getattr(args, hp, None)
            if value is not None:
                self.__dict__[hp] = value
|
||||
|
||||
|
||||
class DecoderConfig(object):
    """Hyperparameters for a torchscale Decoder.

    All values are keyword overrides of the defaults below; unknown keys are
    silently ignored.
    """

    def __init__(self, **kwargs):
        # Model shape
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        # Regularization / activations
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        # Mixture-of-experts
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        # Relative position bias
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        # Normalization / init schemes
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        # deepnorm and subln are mutually exclusive. Fix: deepnorm now also
        # forces subln off — previously the default subln=True immediately
        # re-enabled pre-norm below, silently defeating deepnorm unless the
        # caller remembered to pass subln=False as well.
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True

    def override(self, args):
        """Overwrite hyperparameters with any non-None attributes of *args*."""
        for hp in self.__dict__.keys():
            value = getattr(args, hp, None)
            if value is not None:
                self.__dict__[hp] = value
|
||||
|
||||
|
||||
class EncoderDecoderConfig(object):
    """Hyperparameters for a torchscale EncoderDecoder.

    All values are keyword overrides of the defaults below; unknown keys are
    silently ignored.
    """

    def __init__(self, **kwargs):
        # Encoder shape
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        # Decoder shape
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        # Regularization / activations
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        # Mixture-of-experts
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        # Relative position bias
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        # Normalization / init schemes
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        # deepnorm and subln are mutually exclusive. Fix: deepnorm now also
        # forces subln off — previously the default subln=True immediately
        # re-enabled pre-norm below, silently defeating deepnorm unless the
        # caller remembered to pass subln=False as well.
        if self.deepnorm:
            self.encoder_normalize_before = False
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.encoder_normalize_before = True
            self.decoder_normalize_before = True

    def override(self, args):
        """Overwrite hyperparameters with any non-None attributes of *args*."""
        for hp in self.__dict__.keys():
            value = getattr(args, hp, None)
            if value is not None:
                self.__dict__[hp] = value
|
447
torchscale/architecture/decoder.py
Normal file
447
torchscale/architecture/decoder.py
Normal file
|
@ -0,0 +1,447 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from fairscale.nn import checkpoint_wrapper, wrap
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
||||
from torchscale.component.multihead_attention import MultiheadAttention
|
||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||
from torchscale.component.droppath import DropPath
|
||||
from torchscale.architecture.utils import init_bert_params
|
||||
from torchscale.component.relative_position_bias import RelativePositionBias
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
    """One transformer decoder layer: self-attention, optional cross-attention,
    and an FFN (or MoE layer), with pre-/post-norm and DeepNet scaling support.
    """

    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        if args.drop_path_rate > 0:
            # Stochastic depth: drop probability grows linearly with depth.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        # True -> pre-norm, False -> post-norm.
        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        # Cross-attention only exists in encoder-decoder models.
        if not is_encoder_decoder:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.decoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = self.build_ffn(
                self.embed_dim,
                self.args,
            )
        else:
            # MoE layer: top-1 or top-2 gating over `moe_expert_count` experts.
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)

        self.final_layer_norm = LayerNorm(self.embed_dim)

        if args.deepnorm:
            # DeepNet residual scaling; constants differ for enc-dec vs dec-only.
            if is_encoder_decoder:
                self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
            else:
                self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

        if args.subln:
            # Sub-LayerNorm applied inside the FFN (passed via args.subln below).
            self.ffn_layernorm = LayerNorm(self.ffn_dim)
        else:
            self.ffn_layernorm = None

    def build_ffn(self, embed_dim, args):
        """Dense feed-forward sub-layer used when this is not an MoE layer."""
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        """Causal self-attention over the decoder states."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Cross-attention from decoder queries to encoder outputs."""
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=False,
            encoder_decoder_attention=True,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        # DeepNet-style residual: alpha == 1.0 unless args.deepnorm.
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        self_attn_rel_pos=None,
        cross_attn_rel_pos=None,
    ):
        """Run one decoder layer.

        ``x`` is time-major (seq, batch, dim) — presumably, given the
        transposes around the MoE layer; confirm against the caller.
        Returns ``(x, attn, None, l_aux)`` where ``l_aux`` is the MoE
        auxiliary loss (None for dense layers).
        """
        # --- causal self-attention block ---
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            rel_pos=self_attn_rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # --- cross-attention block (encoder-decoder only) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=None,
                rel_pos=cross_attn_rel_pos,
            )
            x = self.dropout_module(x)

            if self.drop_path is not None:
                x = self.drop_path(x)

            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # --- FFN / MoE block ---
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            # MOELayer expects batch-major input, hence the transposes.
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x, attn, None, l_aux
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
embed_tokens=None,
|
||||
embed_positions=None,
|
||||
output_projection=None,
|
||||
is_encoder_decoder=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.args = args
|
||||
|
||||
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
||||
|
||||
embed_dim = args.decoder_embed_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = embed_positions
|
||||
|
||||
if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
|
||||
self.output_projection = self.build_output_projection(args)
|
||||
else:
|
||||
self.output_projection = output_projection
|
||||
|
||||
if args.layernorm_embedding:
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
else:
|
||||
self.layernorm_embedding = None
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
moe_freq = args.moe_freq
|
||||
for i in range(args.decoder_layers):
|
||||
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
|
||||
self.layers.append(
|
||||
self.build_decoder_layer(
|
||||
args,
|
||||
depth=i,
|
||||
is_moe_layer=is_moe_layer,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
)
|
||||
)
|
||||
|
||||
self.num_layers = len(self.layers)
|
||||
|
||||
if args.decoder_normalize_before:
|
||||
self.layer_norm = LayerNorm(embed_dim)
|
||||
else:
|
||||
self.layer_norm = None
|
||||
|
||||
self.output_projection = output_projection
|
||||
|
||||
self.self_attn_relative_position = None
|
||||
self.cross_attn_relative_position = None
|
||||
|
||||
if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
|
||||
self.self_attn_relative_position = RelativePositionBias(
|
||||
num_buckets=args.rel_pos_buckets,
|
||||
max_distance=args.max_rel_pos,
|
||||
n_heads=args.decoder_attention_heads,
|
||||
)
|
||||
if is_encoder_decoder:
|
||||
self.cross_attn_relative_position = RelativePositionBias(
|
||||
num_buckets=args.rel_pos_buckets,
|
||||
max_distance=args.max_rel_pos,
|
||||
n_heads=args.decoder_attention_heads,
|
||||
)
|
||||
|
||||
if args.bert_init:
|
||||
self.apply(init_bert_params)
|
||||
|
||||
if args.deepnorm:
|
||||
if is_encoder_decoder:
|
||||
init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
|
||||
else:
|
||||
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
|
||||
for name, p in self.named_parameters():
|
||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
||||
p.data.div_(init_scale)
|
||||
|
||||
if args.subln:
|
||||
if is_encoder_decoder:
|
||||
init_scale = math.sqrt(math.log(args.decoder_layers * 3))
|
||||
else:
|
||||
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
|
||||
for name, p in self.named_parameters():
|
||||
if 'encoder_attn' in name:
|
||||
continue
|
||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
||||
p.data.mul_(init_scale)
|
||||
|
||||
def build_output_projection(
|
||||
self,
|
||||
args,
|
||||
):
|
||||
if args.share_decoder_input_output_embed:
|
||||
output_projection = torch.nn.Linear(
|
||||
self.embed_tokens.weight.shape[1],
|
||||
self.embed_tokens.weight.shape[0],
|
||||
bias=False,
|
||||
)
|
||||
output_projection.weight = self.embed_tokens.weight
|
||||
else:
|
||||
output_projection = torch.nn.Linear(
|
||||
args.decoder_embed_dim, args.vocab_size, bias=False
|
||||
)
|
||||
torch.nn.init.normal_(
|
||||
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
||||
)
|
||||
return output_projection
|
||||
|
||||
def build_decoder_layer(
|
||||
self,
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=False,
|
||||
is_encoder_decoder=False
|
||||
):
|
||||
layer = DecoderLayer(
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=is_moe_layer,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
)
|
||||
if args.checkpoint_activations:
|
||||
layer = checkpoint_wrapper(layer)
|
||||
if args.fsdp:
|
||||
layer = wrap(layer)
|
||||
return layer
|
||||
|
||||
def forward_embedding(
|
||||
self,
|
||||
tokens,
|
||||
token_embedding=None,
|
||||
incremental_state=None,
|
||||
):
|
||||
positions = None
|
||||
if self.embed_positions is not None:
|
||||
positions = self.embed_positions(tokens, incremental_state=incremental_state)
|
||||
|
||||
if incremental_state is not None:
|
||||
tokens = tokens[:, -1:]
|
||||
if positions is not None:
|
||||
positions = positions[:, -1:]
|
||||
|
||||
if token_embedding is None:
|
||||
token_embedding = self.embed_tokens(tokens)
|
||||
|
||||
x = embed = self.embed_scale * token_embedding
|
||||
|
||||
if positions is not None:
|
||||
x += positions
|
||||
|
||||
if self.layernorm_embedding is not None:
|
||||
x = self.layernorm_embedding(x)
|
||||
|
||||
x = self.dropout_module(x)
|
||||
|
||||
return x, embed
|
||||
|
||||
def forward(
    self,
    prev_output_tokens,
    self_attn_padding_mask=None,
    encoder_out=None,
    incremental_state=None,
    features_only=False,
    return_all_hiddens=False,
    token_embeddings=None,
    **kwargs
):
    """Run the decoder stack.

    Returns a tuple ``(x, extra)``: ``x`` is either hidden features
    (``features_only=True``) or vocabulary logits, and ``extra`` carries
    per-layer hidden states, MoE auxiliary losses and the last layer's
    attention weights.
    """
    # embed tokens and positions
    x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
    # (B, T, C) -> (T, B, C): the layers operate time-first.
    x = x.transpose(0, 1)

    # relative postion
    self_attn_rel_pos_bias = None
    slen = prev_output_tokens.size(1)
    if self.self_attn_relative_position is not None:
        self_attn_rel_pos_bias = self.self_attn_relative_position(
            batch_size=x.size(1),
            qlen=slen,
            klen=slen
        )
        if incremental_state is not None:
            # Only the current (last) query position is needed when decoding.
            self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
    cross_attn_rel_pos_bias = None
    if self.cross_attn_relative_position is not None:
        cross_attn_rel_pos_bias = self.cross_attn_relative_position(
            batch_size=x.size(1),
            qlen=slen,
            klen=encoder_out["encoder_out"].size(0),
        )
        if incremental_state is not None:
            cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]

    # decoder layers
    inner_states = [x]

    if encoder_out is None:
        l_aux = []
    else:
        # NOTE(review): when the encoder supplied "l_aux", the appends below
        # mutate that same list in place — confirm this is intended.
        l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

    for idx, layer in enumerate(self.layers):
        if incremental_state is None:
            # Causal mask: strict upper triangle filled with -inf.
            self_attn_mask = torch.triu(
                torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
            )
        else:
            # KV cache makes the mask unnecessary (single query step).
            self_attn_mask = None
            if idx not in incremental_state:
                incremental_state[idx] = {}

        x, layer_attn, _, l_aux_i = layer(
            x,
            encoder_out["encoder_out"] if encoder_out is not None else None,
            encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
            incremental_state[idx] if incremental_state is not None else None,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            self_attn_rel_pos=self_attn_rel_pos_bias,
            cross_attn_rel_pos=cross_attn_rel_pos_bias,
        )
        l_aux.append(l_aux_i)
        inner_states.append(x)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    # (T, B, C) -> (B, T, C)
    x = x.transpose(0, 1)

    if not features_only:
        x = self.output_layer(x)

    # NOTE(review): `layer_attn` is unbound if the decoder has zero layers —
    # confirm callers always configure decoder_layers > 0.
    return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}
|
||||
|
||||
def output_layer(self, features):
    """Project hidden features to vocabulary logits."""
    return self.output_projection(features)
|
367
torchscale/architecture/encoder.py
Normal file
367
torchscale/architecture/encoder.py
Normal file
|
@ -0,0 +1,367 @@
|
|||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from fairscale.nn import checkpoint_wrapper, wrap
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
||||
from torchscale.component.multihead_attention import MultiheadAttention
|
||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
|
||||
from torchscale.component.droppath import DropPath
|
||||
from torchscale.architecture.utils import init_bert_params
|
||||
from torchscale.component.relative_position_bias import RelativePositionBias
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention plus a feed-forward
    sublayer (dense FFN or mixture-of-experts), with pre-/post-LayerNorm,
    optional stochastic depth (DropPath) and DeepNet residual scaling."""

    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        if args.drop_path_rate > 0:
            # Stochastic-depth rate increases linearly with layer depth.
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                )
            )
        else:
            # MoE layers are not supported together with multiway modules.
            assert not self.args.multiway
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))

        if args.deepnorm:
            # DeepNet residual-scaling factor; depends on the stack depths.
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        # Dense (non-MoE) position-wise feed-forward sublayer.
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        # The skip branch is scaled by alpha (1.0 unless deepnorm is on).
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None
    ):
        """Run the layer on time-first input ``x``.

        Returns ``(x, l_aux)`` where ``l_aux`` is the MoE auxiliary loss
        (``None`` for dense layers).
        """
        if attn_mask is not None:
            # Turn boolean-style mask entries into a large negative bias.
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            # MOELayer expects batch-first tensors, hence the transposes.
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Stack of ``EncoderLayer``s with embedding scaling, optional learned
    relative-position bias, multiway support, and DeepNet/SubLN weight
    re-scaling applied after construction."""

    def __init__(
        self,
        args,
        embed_tokens=None,
        embed_positions=None,
        output_projection=None,
        is_encoder_decoder=False,
        **kwargs
    ):
        # args is set before super().__init__ so cooperative bases receiving
        # **kwargs can already observe it.
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        embed_dim = args.encoder_embed_dim
        # Standard Transformer sqrt(d) embedding scale, unless disabled.
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        # Encoder-only models with a vocabulary build their own output head.
        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.encoder_layers):
            # Every moe_freq-th layer is a mixture-of-experts layer.
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_encoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder
                )
            )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            # DeepNet init: shrink selected projection weights by a
            # depth-dependent factor after any other initialization.
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.div_(init_scale)

        if args.subln:
            # SubLN init: grow the same projections by a log-depth factor.
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        """Vocabulary head; tied to the input embedding when configured."""
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == 'language'
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.encoder_embed_dim, args.vocab_size, bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
            )
        return output_projection

    def build_encoder_layer(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False
    ):
        """Build one EncoderLayer, optionally checkpointed / FSDP-wrapped."""
        layer = EncoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
    ):
        """Embed tokens (or use the given embeddings), add positions, apply
        optional LayerNorm and dropout; returns ``(x, scaled_embedding)``."""
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens)
            else:
                # No token ids (e.g. pure embedding input): derive positions
                # from the embedding tensor itself.
                x = embed + self.embed_positions(x)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        **kwargs
    ):
        """Encode a batch.

        Returns a dict with time-first ``encoder_out``, the raw scaled
        embedding, the padding mask, optional per-layer hidden states, and
        per-layer MoE auxiliary losses.
        """
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            # Default mask: no position is padding.
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(
                    src_tokens,
                    device=src_tokens.device
                ).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            # Broadcast the modality split index to all multiway submodules.
            self.apply(set_split_position(multiway_split_position))

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
        # Zero out embeddings at padded positions.
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # (B, T, C) -> (T, B, C)
        x = x.transpose(0, 1)

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(
                batch_size=x.size(1),
                qlen=x.size(0),
                klen=x.size(0)
            )

        l_aux = []
        for layer in self.layers:
            x, l_aux_i = layer(
                x, encoder_padding_mask=encoder_padding_mask,
                rel_pos=rel_pos_bias
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
        }
|
61
torchscale/architecture/encoder_decoder.py
Normal file
61
torchscale/architecture/encoder_decoder.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
import torch.nn as nn
|
||||
from torchscale.architecture.encoder import Encoder
|
||||
from torchscale.architecture.decoder import Decoder
|
||||
|
||||
|
||||
class EncoderDecoder(nn.Module):
    """Seq2seq wrapper wiring an ``Encoder`` to a ``Decoder``, with optional
    sharing of embedding matrices across the two sides."""

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args
        if args.share_all_embeddings:
            # NOTE: mutates the shared args namespace so the decoder also
            # ties its output projection to the (shared) input embedding.
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            is_encoder_decoder=True,
            **kwargs
        )

        if args.share_all_embeddings and decoder_embed_tokens is None:
            # Reuse the encoder's embedding matrix on the decoder side.
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            is_encoder_decoder=True,
            **kwargs
        )

    def forward(
        self,
        src_tokens,
        prev_output_tokens,
        return_all_hiddens=False,
        features_only=False,
        **kwargs
    ):
        """Encode ``src_tokens``, then decode conditioned on the result."""
        encoder_out = self.encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out
|
30
torchscale/architecture/utils.py
Normal file
30
torchscale/architecture/utils.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import torch.nn as nn
|
||||
from torchscale.component.multihead_attention import MultiheadAttention
|
||||
from torchscale.component.multiway_network import MultiwayNetwork
|
||||
|
||||
|
||||
def init_bert_params(module):
    """BERT-style initialization visitor for ``nn.Module.apply``.

    Weights are drawn from N(0, 0.02); linear biases and embedding padding
    rows are zeroed; attention projections (including both branches of
    multiway modules) are re-initialized the same way.
    """

    def fill_normal_(data):
        # Sampling happens on a CPU copy and is then written back —
        # presumably so the draw is device-independent; confirm if relied on.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        fill_normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()

    if isinstance(module, nn.Embedding):
        fill_normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    if isinstance(module, MultiheadAttention):
        for proj in (module.q_proj, module.k_proj, module.v_proj):
            if isinstance(proj, MultiwayNetwork):
                # Multiway modules carry two parallel branches (A and B).
                fill_normal_(proj.A.weight.data)
                fill_normal_(proj.B.weight.data)
            else:
                fill_normal_(proj.weight.data)
|
0
torchscale/component/__init__.py
Normal file
0
torchscale/component/__init__.py
Normal file
16
torchscale/component/droppath.py
Normal file
16
torchscale/component/droppath.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from timm.models.layers import drop_path
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Per-sample path dropping (stochastic depth) for the main branch of
    residual blocks; delegates to timm's functional ``drop_path``."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # No-op in eval mode (self.training is False).
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)
|
120
torchscale/component/embedding.py
Normal file
120
torchscale/component/embedding.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class VisionLanguageEmbedding(nn.Module):
    """Embed text tokens, image tokens, or both.

    With both modalities present, vision embeddings come first and the two
    are concatenated along the sequence dimension (dim=1).
    """

    def __init__(
        self,
        text_embed,
        vision_embed
    ):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(
        self,
        textual_tokens,
        visual_tokens,
        **kwargs
    ):
        # Single-modality shortcuts.
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)
        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        vision_out = self.vision_embed(visual_tokens)
        text_out = self.text_embed(textual_tokens)
        return torch.cat([vision_out, text_out], dim=1)
|
||||
|
||||
|
||||
class VisionEmbedding(nn.Module):
    """Image-to-patch embedding via a patch-sized strided convolution, with
    optional mask-token substitution and an optional prepended CLS token."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        grid_h = img_size[0] // patch_size[0]
        grid_w = img_size[1] // patch_size[1]
        self.patch_shape = (grid_h, grid_w)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h

        # One token per patch: conv kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if contain_mask_token else None
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if prepend_cls_token else None

    def forward(
        self,
        x,
        masked_position=None,
        **kwargs
    ):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        if masked_position is not None:
            assert self.mask_token is not None
            # Blend: masked positions are replaced by the learned mask token.
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        return x
|
||||
|
||||
|
||||
class TextEmbedding(nn.Embedding):
    """Token embedding with N(0, d**-0.5) init and a zeroed padding row."""

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
        self._fill_padding_idx_with_zero()
|
||||
|
||||
|
||||
class PositionalEmbedding(nn.Embedding):
    """Learned positional embedding; positions are supplied explicitly or
    derived from the sequence length of ``x``."""

    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # being consistent with Fairseq, which starts from 2.
            positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
        # Look up with this module's own weight and embedding options.
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
|
119
torchscale/component/feedforward_network.py
Normal file
119
torchscale/component/feedforward_network.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
|
||||
|
||||
class set_torch_seed(object):
    """Context manager that seeds the torch (and, when available, CUDA)
    RNGs, restoring the previous RNG state on exit."""

    def __init__(self, seed):
        assert isinstance(seed, int)
        # Snapshot the current state BEFORE reseeding so __exit__ restores it.
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)
|
||||
|
||||
|
||||
def make_experts(args, embed_dim, expert_ffn_dim):
    """Build this rank's expert FFNs for a mixture-of-experts layer.

    Each expert is constructed under a deterministic seed derived from a
    base seed plus the expert's global index, so ranks that must host
    replicas of the same expert initialize it identically.

    NOTE(review): the base seed comes from a local ``torch.randint`` draw —
    replicas only match across ranks if the default RNG is already
    synchronized at this point; confirm upstream seeding.

    Returns an ``nn.ModuleList`` of the experts local to this rank.
    """
    world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()

    def build_expert(seed):
        # Construct a single FFN expert under a fixed RNG seed.
        with set_torch_seed(seed):
            return FeedForwardNetwork(
                embed_dim,
                expert_ffn_dim,
                args.activation_fn,
                args.dropout,
                args.activation_dropout,
                args.subln
            )

    expert_list = []
    # at least as many experts than gpus
    if args.moe_expert_count >= world_size:
        assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            expert_list.append(build_expert(start_seed + ddp_rank * local_moe_expert_count + i))
    else:
        # More gpus than experts: each expert is replicated across ranks,
        # seeded by its global expert index so replicas match.
        assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
        expert_list.append(build_expert(start_seed + ddp_rank % args.moe_expert_count))

    experts = nn.ModuleList(expert_list)
    return experts
|
||||
|
||||
|
||||
def get_activation_fn(activation):
    """Map an activation name ("relu" or "gelu") to a callable.

    Raises NotImplementedError for any other name.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        # Cast to fp32 before GELU (presumably for reduced-precision
        # stability), then cast back to the input dtype.
        return lambda x: F.gelu(x.float()).type_as(x)
    raise NotImplementedError
|
||||
|
||||
|
||||
class FeedForwardNetwork(nn.Module):
    """Position-wise two-layer feed-forward network with optional inner
    LayerNorm between the linears (SubLN)."""

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        subln=False
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # activation_fn may not be a plain str; normalized for the lookup.
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
        self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        # Extra LayerNorm after the activation when SubLN is enabled.
        self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        """fc1 -> activation -> (LayerNorm) -> fc2, with dropouts; the
        input's leading shape is preserved."""
        x_shape = x.shape
        # Flatten all leading dims so the linears see a 2-D matrix.
        x = x.reshape(-1, x.size(-1))
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        # Restore the original shape (fc2 maps back to embed_dim).
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x
|
117
torchscale/component/multihead_attention.py
Normal file
117
torchscale/component/multihead_attention.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
from .multiway_network import MultiwayWrapper
|
||||
|
||||
|
||||
class MultiheadAttention(nn.Module):
    """Multi-head attention with incremental decoding (KV cache), additive
    attention masks, key padding masks, and relative-position bias.

    Exactly one of ``self_attention`` / ``encoder_decoder_attention`` must
    be set. Inputs are time-first: (seq_len, batch, embed_dim).
    """

    def __init__(
        self,
        args,
        embed_dim,
        num_heads,
        dropout=0.0,
        self_attention=False,
        encoder_decoder_attention=False,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        # Exactly one attention mode must be selected.
        assert self.self_attention ^ self.encoder_decoder_attention

        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
        # Sub-LayerNorm on the attention output (SubLN), self-attention only.
        self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
        self.dropout_module = torch.nn.Dropout(dropout, inplace=True)

    def reset_parameters(self):
        # NOTE(review): accesses .weight on the (possibly wrapped)
        # projections; this only resolves when multiway is disabled —
        # confirm callers never invoke this in multiway mode.
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key,
        value,
        incremental_state=None,
        key_padding_mask=None,
        attn_mask=None,
        rel_pos=None,
    ):
        """Compute attention.

        Returns ``(attn, attn_weights)``: ``attn`` is (tgt_len, bsz,
        embed_dim); ``attn_weights`` is (num_heads, bsz, tgt_len, src_len).
        Note the dropout module is in-place, so in training mode the
        returned weights reflect the dropped values.
        """
        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        src_len, key_bsz, _ = key.size()
        assert key_bsz == bsz, f"{query.size(), key.size()}"
        assert value is not None
        # Fix: the original `assert src_len, bsz == value.shape[:2]` parsed
        # as asserting `src_len` alone (the comparison became the assert
        # message), so value's shape was never actually checked.
        assert (src_len, bsz) == tuple(value.shape[:2])

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        q *= self.scaling

        # (len, bsz, dim) -> (bsz * heads, len, head_dim)
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if incremental_state is not None:
            # Append the new keys/values to the cache, then attend over all.
            if "prev_key" in incremental_state:
                prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
                prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat([prev_key, k], dim=1)
                v = torch.cat([prev_value, v], dim=1)
            incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            src_len = k.size(1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        if attn_mask is not None:
            # Clear NaNs before applying the additive (-inf style) mask.
            attn_weights = torch.nan_to_num(attn_weights)
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Mask out padded source positions across all heads.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if rel_pos is not None:
            rel_pos = rel_pos.view(attn_weights.size())
            attn_weights = attn_weights + rel_pos

        # Softmax in fp32 for stability, then cast back to the input dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        attn = torch.bmm(attn_probs, v)
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        if self.inner_attn_ln is not None:
            attn = self.inner_attn_ln(attn)

        attn = self.out_proj(attn)
        attn_weights = attn_weights.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)

        return attn, attn_weights
|
39
torchscale/component/multiway_network.py
Normal file
39
torchscale/component/multiway_network.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def MultiwayWrapper(args, module, dim=0):
    """Wrap ``module`` in a MultiwayNetwork when multiway mode is enabled;
    otherwise return it untouched."""
    if not args.multiway:
        return module
    return MultiwayNetwork(module, dim=dim)
|
||||
|
||||
|
||||
def set_split_position(position):
    """Return an ``nn.Module.apply``-compatible visitor that sets
    ``split_position`` on every module already exposing that attribute."""

    def visitor(module):
        if hasattr(module, 'split_position'):
            module.split_position = position

    return visitor
|
||||
|
||||
|
||||
class MultiwayNetwork(nn.Module):
    """Two parallel copies (A and B) of a module.

    Input is routed entirely to A (``split_position == -1``), entirely to B
    (``split_position == 0``), or split between them along ``dim`` at
    ``split_position`` with the two outputs concatenated back.
    """

    def __init__(self, module, dim=0):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        # Branch B gets its own fresh initialization.
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            # Default: everything through branch A.
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        first, second = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        return torch.cat([self.A(first, **kwargs), self.B(second, **kwargs)], dim=self.dim)
|
79
torchscale/component/relative_position_bias.py
Normal file
79
torchscale/component/relative_position_bias.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
    """T5-style learned relative-position bias for attention logits.

    Relative distances between query and key positions are mapped to a fixed
    number of buckets (exact for small distances, logarithmic for large ones)
    and each bucket holds one learned scalar per attention head.

    Args:
        bidirectional: if True, positive and negative offsets use separate
            bucket halves; if False, future positions are clamped to 0.
        num_buckets: total number of relative-position buckets.
        max_distance: distances beyond this all land in the last bucket.
        n_heads: number of attention heads (embedding width).
    """

    def __init__(
        self,
        bidirectional=True,
        num_buckets=32,
        max_distance=128,
        n_heads=12
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned bias scalar per (bucket, head).
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        bidirectional=True,
        num_buckets=32,
        max_distance=128
    ):
        """Map signed relative positions to bucket indices (long tensor).

        Buckets 0..max_exact-1 are exact; beyond that, distances are placed
        logarithmically up to *max_distance*, after which everything falls in
        the final bucket.
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split buckets: first half for n >= 0, second half for n < 0.
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Causal: future positions (n < 0) are clamped to bucket 0.
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # Logarithmic placement for distances in [max_exact, max_distance).
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(
        self,
        qlen,
        klen,
        step=None
    ):
        """Return the bias tensor of shape (1, n_heads, qlen, klen).

        *step* offsets the query positions for incremental decoding.
        """
        step = 0 if step is None else step
        context_position = torch.arange(step, step + qlen, dtype=torch.long,
                                        device=self.relative_attention_bias.weight.device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long,
                                       device=self.relative_attention_bias.weight.device)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)

        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            # BUG FIX: self.max_distance was previously not forwarded, so the
            # bucketing silently used the default (128) regardless of the
            # value passed to __init__.
            max_distance=self.max_distance,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(
        self,
        batch_size,
        qlen,
        klen,
        step=None
    ):
        """Return the bias expanded to shape (batch_size * n_heads, qlen, klen)."""
        # shape (batch * num_heads, qlen, klen)
        return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
0
torchscale/component/xmoe/__init__.py
Normal file
0
torchscale/component/xmoe/__init__.py
Normal file
310
torchscale/component/xmoe/moe_layer.py
Normal file
310
torchscale/component/xmoe/moe_layer.py
Normal file
|
@ -0,0 +1,310 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# NOTE: This is a mirror of the code in
|
||||
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
|
||||
try:
|
||||
from fairseq.modules.moe import MOELayer
|
||||
has_fairseq = True
|
||||
Base = MOELayer
|
||||
except ModuleNotFoundError:
|
||||
Base = Module
|
||||
has_fairseq = False
|
||||
|
||||
try:
|
||||
# To enable Tutel MoE optimizations:
|
||||
# python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
|
||||
from tutel import moe as tutel_moe
|
||||
|
||||
has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
|
||||
except ModuleNotFoundError:
|
||||
has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
|
||||
# See https://arxiv.org/pdf/2006.16668.pdf for details.
|
||||
|
||||
# Based on https://github.com/pytorch/pytorch/pull/40762
|
||||
class _AllToAll(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
|
||||
ctx.group = group
|
||||
input = input.contiguous()
|
||||
output = torch.empty_like(input)
|
||||
if torch.distributed.is_initialized():
|
||||
dist.all_to_all_single(output, input, group=group)
|
||||
else:
|
||||
assert group is None
|
||||
output = input
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
|
||||
return (None, _AllToAll.apply(ctx.group, *grad_output))
|
||||
|
||||
|
||||
def _find_my_group_index(grouped_ranks):
    """Return the index of the rank-list in *grouped_ranks* that contains the
    current distributed rank; raise RuntimeError if none does."""
    my_rank = dist.get_rank()
    for idx, ranks in enumerate(grouped_ranks):
        if my_rank in ranks:
            return idx
    raise RuntimeError
|
||||
|
||||
|
||||
def get_moe_group(moe_expert_count):
    """Return the expert-parallel process group for the current rank.

    Groups are created once and cached on the function object
    (``_moe_group_idx`` holds the rank lists, ``_moe_groups`` the groups).
    Returns ``None`` when torch.distributed is not initialized.
    """
    if not dist.is_initialized():
        return None
    if not hasattr(get_moe_group, "_moe_groups"):
        world_size = dist.get_world_size()

        if world_size <= moe_expert_count:
            # At least one expert per rank: each rank is its own group.
            assert moe_expert_count % world_size == 0
            moe_groups = [[i] for i in range(world_size)]
        else:
            # More ranks than experts: each expert is replicated across
            # world_size // moe_expert_count ranks.
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            moe_groups = [
                [i + j * moe_expert_count for j in range(ranks_per_group)]
                for i in range(moe_expert_count)
            ]

        get_moe_group._moe_group_idx = moe_groups
        get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

    my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
    return get_moe_group._moe_groups[my_group_idx]
|
||||
|
||||
|
||||
def get_all2all_group(moe_expert_count):
    """Return the all-to-all process group for the current rank.

    Groups are created once and cached on the function object
    (``_all2all_group_idx`` holds the rank lists, ``_all2all_groups`` the
    groups). Returns ``None`` when torch.distributed is not initialized.
    """
    if not dist.is_initialized():
        return None
    if not hasattr(get_all2all_group, "_all2all_groups"):
        world_size = dist.get_world_size()

        # more experts than world size
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            all2all_groups = [[i for i in range(world_size)]]
        # larger world than num experts
        else:
            assert world_size % moe_expert_count == 0
            ranks_per_group = world_size // moe_expert_count
            all2all_groups = [
                [i * moe_expert_count + j for j in range(moe_expert_count)]
                for i in range(ranks_per_group)
            ]

        get_all2all_group._all2all_group_idx = all2all_groups
        get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]

    my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
    return get_all2all_group._all2all_groups[my_group_idx]
|
||||
|
||||
|
||||
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(input)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """

    def __init__(
        self,
        gate,
        experts,
        args
    ):
        # When fairseq is present, Base is fairseq's MOELayer; skip its
        # __init__ (which has a different signature) and call nn.Module's.
        if has_fairseq:
            super(Base, self).__init__()
        else:
            super().__init__()
        self.gate = gate
        # Accept either a ModuleList of local experts or a single expert.
        if type(experts) == ModuleList:
            self.experts = cast(ModuleList, experts)
        else:
            self.experts = ModuleList([experts])
        self.expert_group = get_moe_group(args.moe_expert_count)
        self.all2all_group = get_all2all_group(args.moe_expert_count)
        self.world_size = dist.get_world_size(group=self.expert_group)
        self.all2all_size = dist.get_world_size(group=self.all2all_group)
        # Tag expert parameters so optimizers/DDP wrappers can treat them
        # specially (they are sharded, not replicated).
        for p in experts.parameters():
            p.expert = True  # type: ignore
        self.num_local_experts = len(self.experts)
        self.args = args
        # Set by prepare_for_inference_(); disables batch padding.
        self.in_generation = False
        # All-to-all timing accumulators, flushed by record_all_to_all_stats().
        self.a2a_cuda_event_intervals = []
        self.a2a_cpu_time_ms = 0.0

    def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        """Route tokens to experts and combine their outputs.

        Args:
            input: a single (batch, seq, model_dim) tensor.
            input_padding_mask: optional (batch, seq) bool mask, True at
                padding positions.

        Returns:
            (output, l_aux): output has the same shape as the input; l_aux is
            the gate's load-balancing auxiliary loss.
        """
        assert len(input) == 1, "only single input Tensor supported"
        input = input[0]
        assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
        if input_padding_mask is not None:
            assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
            assert input_padding_mask.shape[0] == input.shape[0]
            assert input_padding_mask.shape[1] == input.shape[1]
        # assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"

        # Implement Algorithm 2 from GShard paper.
        d_model = input.shape[2]
        # Pad to expected batch size so every worker contributes the same
        # number of tokens to the all-to-all.
        input_shape = list(input.shape)
        expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
        # This indicates that --batch-size or --max-sentences is not specified
        if expected_bsz is None:
            expected_bsz = 0
        # Note: Padding is not necessary at generation time at present
        # because all DDP workers process the same batch. Also, batch size at generation time
        # can be different from that present in the checkpoint state
        if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
            logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
            assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
            padded_input = torch.zeros(
                (expected_bsz, input_shape[1], input_shape[2]),
                dtype=input.dtype, layout=input.layout, device=input.device)
            padded_input[:input_shape[0], :, :] = input
            input = padded_input

            # Padding rows are marked True (padding) in the mask; real rows
            # keep their original mask (or False if none was given).
            padded_input_padding_mask = torch.ones(
                (expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
            )
            if input_padding_mask is not None:
                padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
            else:
                padded_input_padding_mask[:input_shape[0], :] = False
            input_padding_mask = padded_input_padding_mask

        # Reshape into S tokens by dropping sequence dimension.
        reshaped_input = input.reshape(-1, d_model)
        reshaped_input_shape = reshaped_input.shape
        reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None

        # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
        # Pro of --max-tokens: more flexible for MT variable sequence lengths
        # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
        if expected_bsz == 0:
            # Agree on the max token count across workers, then pad to it.
            expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
            dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
            expected_dim = int(expected_dim.item())
            padded_input = torch.zeros(
                (expected_dim, reshaped_input_shape[1]),
                dtype=input.dtype, layout=input.layout, device=input.device)
            padded_input[:reshaped_input_shape[0], :] = reshaped_input
            reshaped_input = padded_input

            padded_input_padding_mask = torch.ones(
                (expected_dim,), dtype=torch.bool, device=padded_input.device
            )
            if reshaped_input_padding_mask is not None:
                padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
            else:
                padded_input_padding_mask[:reshaped_input_shape[0]] = False
            reshaped_input_padding_mask = padded_input_padding_mask

        if has_tutel:
            # Tutel fast path: the gate returns dispatch indices/locations
            # directly and a fused dispatcher builds the expert inputs.
            l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
            S, M = reshaped_input.size(0), reshaped_input.size(1)

            if not hasattr(self, '_tutel_dispatcher'):
                self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
            self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
            dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
        else:
            # Dense path: build the (E, C, S) dispatch mask and matmul.
            l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)

            dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0)  # S,E,C -> E,C,S
            E, C, S = dispatch_mask.size()
            M = reshaped_input.size(1)
            assert reshaped_input.size() == (S, M)
            # einsum("sec,sm->ecm")
            dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input)  # -> (E*C),M

        if self.all2all_size > 1:
            # Exchange token chunks so each rank holds its experts' tokens.
            dispatched_input = self.all_to_all_wrapper(dispatched_input)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
        chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            expert_outputs += [expert(chunk)]
        expert_output = torch.cat(expert_outputs, dim=1)

        if self.all2all_size > 1:
            # Send expert outputs back to the ranks that own the tokens.
            expert_output = self.all_to_all_wrapper(expert_output)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)

        if has_tutel:
            combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
        else:
            # einsum("sec,ecm->sm")
            combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))

        # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
        combined_output = combined_output[:reshaped_input_shape[0], :]
        combined_output = combined_output.reshape(input.shape)
        combined_output = combined_output[:input_shape[0], :, :]

        self.record_all_to_all_stats()

        return combined_output, l_aux

    def prepare_for_inference_(self):
        """Switch to generation mode: disables batch-size padding in forward."""
        self.in_generation = True

    def all_to_all_wrapper(self, input: Tensor):
        """Run the all-to-all exchange while recording CPU/CUDA timings."""
        dummy_a2a = getattr(self.args, 'dummy_a2a', False)
        if dummy_a2a:
            input = input.contiguous()
            # NOTE(review): `output` is computed but `input` is returned, so
            # the detached clone is unused — looks intentional for a no-op
            # benchmark mode, but confirm against upstream fairseq.
            output = input.detach().clone()
            return input
        # always record times, since it is not a lot of overhead
        # if we do not log it we simply clear it off in record_all_to_all_stats
        cuda_start = torch.cuda.Event(enable_timing=True)
        cuda_end = torch.cuda.Event(enable_timing=True)
        cpu_start = time.time() * 1000
        cuda_start.record()
        output = _AllToAll.apply(self.all2all_group, input)
        cuda_end.record()
        cpu_end = time.time() * 1000
        self.a2a_cpu_time_ms += (cpu_end - cpu_start)
        self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
        return output

    def record_all_to_all_stats(self):
        """Flush accumulated all-to-all timings into self.metadata, then reset."""
        # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
        record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
        if record_a2a_perf_stats:
            # Synchronize so the CUDA event timings below are complete.
            torch.cuda.synchronize()
            self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
            a2a_cuda_time_ms = 0.0
            for ev_start, ev_end in self.a2a_cuda_event_intervals:
                a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
            self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
        # reset stats
        self.a2a_cpu_time_ms = 0.0
        self.a2a_cuda_event_intervals = []
|
459
torchscale/component/xmoe/routing.py
Normal file
459
torchscale/component/xmoe/routing.py
Normal file
|
@ -0,0 +1,459 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
|
||||
# Code is inspired by Top2GatingOnLogits from lingvo:
|
||||
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
|
||||
|
||||
# NOTE: This is a mirror of the code in
|
||||
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||
|
||||
from typing import Callable, Dict, Tuple, Optional
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .moe_layer import has_tutel, fused_cumsum_sub_one
|
||||
|
||||
# use a fixed temperature to compute balance loss
|
||||
TEMPERATURE_FOR_L_UAX = 0.07
|
||||
|
||||
# maximum capacity of 1 expert as a fraction of number of tokens in the batch
|
||||
# Note: setting this to 1.0 causes inference to significantly slow down
|
||||
EVAL_CAPACITY_TOKEN_FRACTION = 0.25
|
||||
|
||||
# logging
|
||||
SAMPLE_FRACTION = 0.2
|
||||
|
||||
|
||||
def top1gating(
    logits: torch.Tensor,
    input_mask: Optional[torch.Tensor] = None,
    use_fp32=False,
    capacity_factor=1.0,
    eval_mode=False,
    moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
    use_xmoe=False,
    gate_obj=None,
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
    """Implements Top1Gating on logits.

    Each token is routed to its single highest-scoring expert, subject to a
    per-expert capacity limit.

    Args:
        logits: (num_tokens, num_experts) gate scores.
        input_mask: optional (num_tokens,) bool mask, True at padding tokens.
        use_fp32: compute gating in float32 and cast results back.
        capacity_factor: multiplier on the ceil(S/E) per-expert capacity.
        eval_mode: use moe_eval_capacity_token_fraction instead of
            capacity_factor for sizing capacity.
        use_xmoe, gate_obj: accepted but unused in this function body.

    Returns:
        Without tutel: (l_aux, combine_weights, dispatch_mask, metadata).
        With tutel the return is a 7-tuple of dispatch components instead
        (see the ``has_tutel`` branch below).
    """
    metadata = {}
    if use_fp32:
        orig_dtype = logits.dtype
        logits = logits.float()

    gates = F.softmax(logits, dim=1)
    metadata["entropy_gating"] = entropy(probs=gates).mean().detach()

    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
        capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
    else:
        # capacity = capacity_factor * S/E
        capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))

    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1)
    mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
    if input_mask is not None and input_mask.any():
        # Zero out routing for padding tokens so they consume no capacity.
        nonpadding = ~ input_mask
        mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)

    # for logging (percent of tokens routed to each expert)
    expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
    metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
    expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny

    sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
    metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
    metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()

    gates1_s = (gates * mask1).sum(dim=1)

    # Compute locations in capacity buffer
    locations1 = fused_cumsum_sub_one(mask1)

    # Compute l_aux (GShard load-balancing loss: mean fraction of probability
    # times mean fraction of tokens per expert, scaled by num_experts^2)
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.to(gates.dtype), dim=0)

    l_aux = torch.mean(me * ce)
    l_aux = l_aux * num_experts * num_experts

    if has_tutel:
        # Tutel builds the dispatch itself from indices/locations/gates.
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]

    # Remove locations outside capacity from mask
    mask1 = mask1 * torch.lt(locations1, capacity)
    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Calculate combine_weights and dispatch_mask
    gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype)  # einsum("s,se->se")
    # locations1_sc = num_tokens * capacity
    locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
    combine1_sec = torch.bmm(
        # einsum("se,sc->sec")
        gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
    )
    dispatch_mask = combine1_sec.bool()
    if use_fp32:
        return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
    else:
        return l_aux, combine1_sec, dispatch_mask, metadata
|
||||
|
||||
|
||||
class Top1Gate(torch.nn.Module):
    """Gate module which implements Top1Gating (cf. Top2Gating in Gshard_).
    ::

        gate = Top1Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (ints):
            number of experts in model
    """

    # Type of the gate projection in the non-xMoE path; in the xMoE path
    # ``wg`` is registered as a raw nn.Parameter instead.
    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        use_fp32=False,
        input_noise_type=None,
        capacity_factor=1.0,
        moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
        use_xmoe=False,
    ) -> None:
        # TODO: merge this to top2gate.py
        #
        super().__init__()

        if not use_xmoe:
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        else:
            # xMoE: project tokens to a low-dim space and score against
            # orthogonally-initialized expert embeddings via cosine similarity.
            self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
            wg = torch.empty(num_experts, 16)
            torch.nn.init.orthogonal_(wg, gain=0.32)
            self.register_parameter("wg", torch.nn.Parameter(wg))

        self.use_xmoe = use_xmoe
        self.use_fp32 = use_fp32
        # NOTE(review): input_noise_type is stored but never read here.
        self.input_noise_type = input_noise_type
        self.capacity_factor = capacity_factor
        self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction

    def forward(self, input, mask=None):  # type: ignore
        """Compute gate logits for *input* and delegate routing to top1gating."""
        if self.use_xmoe:
            input = self.wg_reduction(input)
            with torch.no_grad():
                # Renormalize expert embeddings to a fixed L2 norm (1.5)
                # in-place before scoring.
                wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
                self.wg.mul_(1.5 / wg_norm)
            logits = self._cosine(input, self.wg)
            logits = self._make_finite(logits)
        else:
            logits = self.wg(input)

        return top1gating(
            logits,
            mask,
            use_fp32=self.use_fp32,
            capacity_factor=self.capacity_factor,
            eval_mode=not self.training,
            moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
            use_xmoe=self.use_xmoe,
            gate_obj=self,
        )

    def _make_finite(self, scores):
        """Replace non-finite entries with the minimum finite score (in place)."""
        ok = scores.isfinite()
        if not ok.all():
            # NaNs here can break the assignment algorithm
            scores[~ok] = scores[ok].min()
        return scores

    def _get_gating_temperature(self, eps=1e-4):
        # NOTE(review): self.gating_t is never assigned in this class —
        # calling this method would raise AttributeError; it appears to be
        # dead code carried over from a temperature-scaled variant. Confirm
        # before removing.
        if self.gating_t.data.item() < eps:
            return eps
        return self.gating_t

    def _cosine(self, mat1, mat2, eps=1e-4):
        """Score rows of *mat1* against L2-normalized rows of *mat2*.

        Only mat2 is normalized here; mat1 is used as-is (see the
        commented-out line below).
        """
        assert mat1.dim() == 2
        assert mat2.dim() == 2
        # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
        mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
        return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
|
||||
|
||||
|
||||
# Per-device cache of reparameterized Gumbel(0, 1) samplers.
gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw a reparameterized Gumbel(0, 1) sample of *shape* on *device*.

    The sampler for each device is created once and memoized in
    ``gumbel_map`` to avoid re-building the distribution per call.
    """
    sampler = gumbel_map.get(device)
    if sampler is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        sampler = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = sampler
    return sampler(shape)
|
||||
|
||||
|
||||
def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
    """One-hot encode *indices* along a new last dimension of size *num_classes*.

    The result keeps the dtype and device of *indices*. The last dimension of
    *indices* must be 1 (set ``unsqueeze_indices=True`` to add it).
    """
    if unsqueeze_indices:
        indices = indices.unsqueeze(-1)
    assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
    out_shape = indices.shape[:-1] + (num_classes,)
    output = torch.zeros(out_shape, device=indices.device, dtype=indices.dtype)
    last_dim = len(output.shape) - 1
    output.scatter_(last_dim, indices, 1)
    return output
|
||||
|
||||
|
||||
def entropy(probs):
    """Shannon entropy (in nats) along the last dimension of *probs*."""
    log_probs = torch.distributions.utils.probs_to_logits(probs)
    plogp = probs * log_probs
    return -plogp.sum(-1)
|
||||
|
||||
|
||||
def top2gating(
    logits: torch.Tensor,
    input_mask: Optional[torch.Tensor] = None,
    use_fp32=False,
    second_expert_policy='sampling',
    normalize_gate_prob_before_dropping=False,
    eval_mode=False,
    moe_eval_capacity_token_fraction=0.25,
    batch_prioritized_routing=False,
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
    """Implements Top2Gating on logits.

    Each token is routed to its top expert and a second expert (chosen per
    *second_expert_policy*: 'sampling' adds Gumbel noise before the argmax,
    'random' keeps the second assignment with probability 2*gate2), subject
    to a per-expert capacity of 2*ceil(S/E).

    Returns:
        Without tutel: (l_aux, combine_weights, dispatch_mask, metadata).
        With tutel the return is a 7-tuple of dispatch components instead
        (see the ``has_tutel`` branch below).
    """
    metadata = {}
    if use_fp32:
        orig_dtype = logits.dtype
        logits = logits.float()
    gates = F.softmax(logits, dim=1)
    metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
        capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
    else:
        # capacity = 2S/E
        capacity = 2 * math.ceil(num_tokens / num_experts)

    # Create a mask for 1st's expert per token
    indices1_s = torch.argmax(gates, dim=1, keepdim=True)
    mask1 = one_hot(indices1_s, num_experts)
    if second_expert_policy == 'sampling':
        # Create a mask for 2nd's expert per token using Gumbel-max trick
        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    else:
        logits_w_noise = logits
    # Replace top-expert with min value
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
    mask2 = one_hot(indices2_s, num_experts)
    gates1_s = (gates * mask1).sum(dim=1)
    gates2_s = (gates * mask2).sum(dim=1)

    if normalize_gate_prob_before_dropping:
        # Normalize gate probabilities
        denom_s = gates1_s + gates2_s
        # Avoid divide-by-zero
        denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
        gates1_s = gates1_s / denom_s
        gates2_s = gates2_s / denom_s

    if second_expert_policy == 'random':
        # Keep the second assignment with probability min(1, 2 * gate2).
        sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
        mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)

    # Compute locations in capacity buffer
    if input_mask is not None and input_mask.any():
        # Padding tokens consume no capacity in either assignment.
        nonpadding = ~ input_mask
        mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
        mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)

    if batch_prioritized_routing:
        # if batch_prioritized_routing:
        # Assign capacity slots in order of decreasing top-gate score, so
        # confident tokens are dropped last (sort, cumsum, unsort).
        importance_scores = -1 * gates.max(dim=1)[0]
        sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
        sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
        importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]

        sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
        sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
        importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]

        importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)

        locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
    else:
        locations1 = fused_cumsum_sub_one(mask1)
        locations2 = fused_cumsum_sub_one(mask2)
        # Update 2nd's location by accounting for locations of 1st
        locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # Compute l_aux (GShard load-balancing loss on the top-1 assignment)
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.to(gates.dtype), dim=0)
    l_aux = torch.mean(me * ce)
    l_aux = l_aux * num_experts * num_experts

    # for logging purposes
    metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
    metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)

    # Remove locations outside capacity from mask
    mask1_, mask2_ = mask1, mask2
    mask1 = mask1 * torch.lt(locations1, capacity)
    mask2 = mask2 * torch.lt(locations2, capacity)

    # for logging (percent of tokens routed to each expert)
    expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
    metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
    expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny

    expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
    metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
    expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny

    sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
    metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
    metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()

    metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
    metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()

    if not normalize_gate_prob_before_dropping:
        # Normalize gate probabilities
        gates1_s = (gates * mask1).sum(dim=1)
        gates2_s = (gates * mask2).sum(dim=1)
        denom_s = gates1_s + gates2_s
        # Avoid divide-by-zero
        denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
        gates1_s /= denom_s
        gates2_s /= denom_s

    if has_tutel:
        # Tutel builds the dispatch from the pre-capacity masks (mask1_/mask2_).
        locations1_s = torch.sum(locations1 * mask1_, dim=1)
        locations2_s = torch.sum(locations2 * mask2_, dim=1)
        return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Calculate combine_weights and dispatch_mask
    gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype)  # einsum("s,se->se")
    gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype)  # einsum("s,se->se")
    locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
    locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
    combine1_sec = torch.bmm(
        # einsum("se,sc->sec")
        gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
    )
    combine2_sec = torch.bmm(
        # einsum("se,sc->sec")
        gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
    )
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()
    if use_fp32:
        return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
    else:
        return l_aux, combine_weights, dispatch_mask, metadata
|
||||
|
||||
|
||||
class Top2Gate(torch.nn.Module):
    """Gate module which implements Top2Gating as described in Gshard_.

    Scores every token against all experts and routes it to its two
    highest-scoring experts::

        gate = Top2Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (ints):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        use_fp32=False,
        second_expert_policy='sampling',
        normalize_gate_prob_before_dropping=False,
        moe_eval_capacity_token_fraction=0.25,
        batch_prioritized_routing=False,
        use_xmoe=False,
    ) -> None:
        super().__init__()
        if use_xmoe:
            # X-MoE routing: project tokens into a 16-dim space and score
            # them against learned expert embeddings via cosine similarity.
            self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
            expert_embeddings = torch.empty(num_experts, 16)
            torch.nn.init.orthogonal_(expert_embeddings, gain=0.32)
            self.register_parameter("wg", torch.nn.Parameter(expert_embeddings))
        else:
            # Plain GShard routing: a single linear scoring layer.
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        self.use_fp32 = use_fp32
        self.second_expert_policy = second_expert_policy
        self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
        self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
        self.batch_prioritized_routing = batch_prioritized_routing
        self.use_xmoe = use_xmoe

    def forward(self, input, mask=None):  # type: ignore
        """Compute routing logits for *input* and delegate to top2gating."""
        if not self.use_xmoe:
            logits = self.wg(input)
        else:
            reduced = self.wg_reduction(input)
            with torch.no_grad():
                # Rescale expert embeddings in place to a fixed L2 norm of 1.5
                # before scoring (outside the autograd graph).
                self.wg.mul_(1.5 / self.wg.norm(p=2.0, dim=1, keepdim=True))
            logits = self._make_finite(self._cosine(reduced, self.wg))
        return top2gating(
            logits,
            mask,
            use_fp32=self.use_fp32,
            second_expert_policy=self.second_expert_policy,
            normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
            eval_mode=not self.training,
            moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
            batch_prioritized_routing=self.batch_prioritized_routing,
        )

    def _cosine(self, mat1, mat2, eps=1e-4):
        """Similarity of token projections (mat1) vs. expert embeddings (mat2).

        Only mat2 is L2-normalized; mat1 is used unnormalized on purpose
        (see the commented-out line), matching the original x-MoE recipe.
        """
        assert mat1.dim() == 2
        assert mat2.dim() == 2
        # mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
        mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
        return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)

    def _make_finite(self, scores):
        """Replace non-finite scores in place with the smallest finite score."""
        bad = ~scores.isfinite()
        if bad.any():
            # NaNs here can break the assignment algorithm
            scores[bad] = scores[~bad].min()
        return scores
|
85
torchscale/model/BEiT3.py
Normal file
85
torchscale/model/BEiT3.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchscale.architecture.encoder import Encoder
|
||||
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
|
||||
from torchscale.component.multiway_network import MultiwayWrapper
|
||||
|
||||
|
||||
class BEiT3(nn.Module):
    """BEiT-3 multiway vision-language encoder.

    Embeds text tokens, image patches, or both (image first, then text,
    concatenated along the sequence dimension) and runs them through a
    shared multiway Encoder. ``multiway_split_position`` tells the encoder
    where the vision span ends: -1 for vision-only, 0 for text-only,
    otherwise the number of vision tokens.
    """

    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        # BEiT-3 requires modality-specific (multiway) encoder blocks and
        # its own output head, so these config invariants must hold.
        assert args.multiway
        assert args.vocab_size > 0
        assert not args.share_encoder_input_output_embed
        self.text_embed = TextEmbedding(
            args.vocab_size,
            args.encoder_embed_dim
        )
        self.vision_embed = VisionEmbedding(
            args.img_size,
            args.patch_size,
            args.in_chans,
            args.encoder_embed_dim,
            contain_mask_token=True,
            prepend_cls_token=True
        )
        embed_positions = MultiwayWrapper(
            args,
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim
            ),
            dim=1,
        )
        self.encoder = Encoder(
            args,
            embed_tokens=None,
            embed_positions=embed_positions,
            output_projection=None,
            is_encoder_decoder=False,
        )

    def forward(
        self,
        textual_tokens=None,
        visual_tokens=None,
        text_padding_position=None,
        vision_masked_position=None,
    ):
        """Encode text and/or image tokens.

        Args:
            textual_tokens: text token ids, or None for vision-only input.
            visual_tokens: image input, or None for text-only input.
            text_padding_position: padding mask for the text tokens
                (presumably 1/True at padded positions — matches how it is
                concatenated with an all-False vision mask below).
            vision_masked_position: masked-patch positions forwarded to
                the vision embedding (masked image modeling).

        Returns:
            The encoder output (whatever ``self.encoder`` returns).
        """
        assert textual_tokens is not None or visual_tokens is not None

        if textual_tokens is None:
            # Vision-only: no padding, entire sequence is the vision way.
            x = self.vision_embed(visual_tokens, vision_masked_position)
            encoder_padding_mask = None
            multiway_split_position = -1
        elif visual_tokens is None:
            # Text-only: vision span is empty.
            x = self.text_embed(textual_tokens)
            encoder_padding_mask = text_padding_position
            multiway_split_position = 0
        else:
            # Vision + text: image tokens first, then text tokens.
            x1 = self.vision_embed(visual_tokens, vision_masked_position)
            multiway_split_position = x1.size(1)
            x2 = self.text_embed(textual_tokens)
            x = torch.cat([x1, x2], dim=1)

            if text_padding_position is not None:
                # Vision tokens are never padded: prepend an all-False mask.
                # Allocate it directly on the target device with bool dtype
                # (the previous .to(device).bool() round-trip allocated a
                # float tensor on CPU and then copied/cast it).
                encoder_padding_mask = torch.cat(
                    [
                        torch.zeros(
                            x1.shape[:-1], dtype=torch.bool, device=x1.device
                        ),
                        text_padding_position
                    ],
                    dim=1,
                )
            else:
                encoder_padding_mask = None

        encoder_out = self.encoder(
            src_tokens=None,
            encoder_padding_mask=encoder_padding_mask,
            token_embeddings=x,
            multiway_split_position=multiway_split_position,
        )

        return encoder_out
|
0
torchscale/model/__init__.py
Normal file
0
torchscale/model/__init__.py
Normal file
Loading…
Reference in New Issue
Block a user