torchscale released
This commit is contained in:
parent 41f6ee5687
commit ede048831f

162  .gitignore  vendored
@@ -348,3 +348,165 @@ MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
171  README.md
@@ -1,11 +1,176 @@
-# OpenScale - Transformers at (any) Scale
# TorchScale - A Library for Transformers at (Any) Scale

-Fundamental research to improve modeling generality and capability, as well as training stability and efficiency of scaling Transformers.
<p>
  <a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
  <a href="https://pypi.org/project/torchscale"><img alt="PyPI" src="https://badge.fury.io/py/torchscale.svg" /></a>
</p>

TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
It implements fundamental research to improve modeling generality and capability, as well as training stability and efficiency, when scaling Transformers.

- Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
- Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423)
- Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)

## News

- November 2022: TorchScale 0.1.1 released

## Installation

To install:
```
pip install torchscale
```

Alternatively, you can develop it locally:
```
git clone https://github.com/microsoft/torchscale.git
cd torchscale
pip install -e .
```

## Getting Started

It takes only a few lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:

```python
>>> from torchscale.architecture.config import EncoderConfig
>>> from torchscale.architecture.encoder import Encoder

>>> config = EncoderConfig(vocab_size=64000)
>>> model = Encoder(config)

>>> print(model)
```
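The research features listed above are switched on through the config. The exact flag names are not documented in this README; the sketch below assumes the `deepnorm`, `moe_freq`, `moe_expert_count`, and `use_xmoe` options that appear in the fairseq example configs added in this commit, so treat it as illustrative rather than a verified API:

```python
>>> # Sketch only: the flag names are assumptions borrowed from the example configs.
>>> config = EncoderConfig(vocab_size=64000, deepnorm=True,
...                        moe_freq=2, moe_expert_count=16, use_xmoe=True)
>>> model = Encoder(config)
```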
We also support the `Decoder` architecture and the `EncoderDecoder` architecture:

```python
# Creating a decoder model
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.architecture.decoder import Decoder

>>> config = DecoderConfig(vocab_size=64000)
>>> decoder = Decoder(config)
>>> print(decoder)

# Creating an encoder-decoder model
>>> from torchscale.architecture.config import EncoderDecoderConfig
>>> from torchscale.architecture.encoder_decoder import EncoderDecoder

>>> config = EncoderDecoderConfig(vocab_size=64000)
>>> encdec = EncoderDecoder(config)
>>> print(encdec)
```
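The objects created above behave like ordinary PyTorch modules (an assumption based on TorchScale being a PyTorch library; the README does not state it explicitly), so standard inspection works, for example counting parameters to check the scale of a configuration:

```python
>>> # Sketch assuming Encoder/Decoder/EncoderDecoder are torch.nn.Module subclasses.
>>> num_params = sum(p.numel() for p in encdec.parameters())
>>> print(f"{num_params:,} parameters")
```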
## Examples

We provide examples of how to use TorchScale in the following scenarios/tasks:

- Language

  * [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining)

  * [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation)

  * [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining)

- Vision

  * ViT/BEiT [In progress]

- Speech

- Multimodal

  * [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]

We plan to provide more examples covering different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!

## Results

### Stability Evaluation

<p align="center">
  <img src="./assets/convergence.png" width="800"/>
</p>

With TorchScale the training curve is smooth, while the baseline Transformer fails to converge.

### Scaling-up Experiments

<p align="center">
  <img src="./assets/scaling_curve.png" width="800"/>
</p>

TorchScale supports arbitrary depths and widths, scaling up the models successfully and painlessly.

## Acknowledgments

Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository.

## Citations

If you find this repository useful, please consider citing our work:

```
@article{deepnet,
  author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
  title   = {{DeepNet}: Scaling Transformers to 1,000 Layers},
  journal = {CoRR},
  volume  = {abs/2203.00555},
  year    = {2022},
}
```

```
@article{magneto,
  author  = {Hongyu Wang and Shuming Ma and Shaohan Huang and Li Dong and Wenhui Wang and Zhiliang Peng and Yu Wu and Payal Bajaj and Saksham Singhal and Alon Benhaim and Barun Patra and Zhun Liu and Vishrav Chaudhary and Xia Song and Furu Wei},
  title   = {Foundation Transformers},
  journal = {CoRR},
  volume  = {abs/2210.06423},
  year    = {2022}
}
```

```
@article{xmoe,
  author  = {Zewen Chi and Li Dong and Shaohan Huang and Damai Dai and Shuming Ma and Barun Patra and Saksham Singhal and Payal Bajaj and Xia Song and Furu Wei},
  title   = {On the Representation Collapse of Sparse Mixture of Experts},
  journal = {CoRR},
  volume  = {abs/2204.09179},
  year    = {2022}
}
```

## Contributing

@@ -19,7 +184,7 @@ provided by the bot. You will only need to do this once across all repos using o

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
-contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.

## Trademarks
BIN  assets/convergence.png  Normal file  (binary file not shown; 98 KiB)
BIN  assets/scaling_curve.png  Normal file  (binary file not shown; 48 KiB)
0  examples/__init__.py  Normal file
233  examples/fairseq/README.md  Normal file
@@ -0,0 +1,233 @@
# Example: Integration with FairSeq

## Setup

```bash
# Install the repo as a package:
git clone https://github.com/msranlp/torchscale.git
cd torchscale
pip install -e .
pip install git+https://github.com/shumingma/fairseq.git@moe
pip install git+https://github.com/shumingma/infinibatch.git
pip install iopath
pip install --upgrade numpy
```

## Example: BERT Pretraining

### Data Format

We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on the fly from disk. It requires the data to be sharded into multiple small files (e.g. 10K lines per file), plus a JSON file that holds some metadata and the paths to these files.

The overall data directory should be organized as follows:
```
Data/
├── json/
│   ├── train.json
│   └── valid.json
├── shard/
│   ├── train/
│   │   ├── 00000.txt
│   │   ├── 00001.txt
│   │   └── ...
│   └── valid/
│       ├── 00000.txt
│       ├── 00001.txt
│       └── ...
├── dict.txt
└── sentencepiece.bpe.model
```

We recommend that each sharded data file contain no more than 10K lines, with one sentence per line; two documents should be separated by an empty line.
```
Document 1 Line 1
Document 1 Line 2
Document 1 Line 3

Document 2 Line 1
Document 2 Line 2

...
```

Also, the JSON file should be in a format like this:
```
[
    {
        "source": [
            "shard/train/00000.txt",
            "shard/train/00001.txt",
            ...
        ],
        "source_lang": "en",
        "weight": 1.0
    }
]
```
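As an illustration, a script along the following lines can produce that layout from a raw text corpus (a sketch using only the Python standard library; the directory and field names simply mirror the layout shown above):

```python
import json
import os

def shard_corpus(corpus_path, data_dir, split="train", lines_per_shard=10000):
    """Split a raw text file into numbered shards and write the JSON metadata."""
    shard_dir = os.path.join(data_dir, "shard", split)
    os.makedirs(shard_dir, exist_ok=True)
    os.makedirs(os.path.join(data_dir, "json"), exist_ok=True)

    sources, buf, shard_id = [], [], 0

    def flush():
        nonlocal buf, shard_id
        path = os.path.join(shard_dir, f"{shard_id:05d}.txt")
        with open(path, "w", encoding="utf-8") as out:
            out.writelines(buf)
        # paths in the JSON are relative to the data directory
        sources.append(os.path.relpath(path, data_dir))
        buf, shard_id = [], shard_id + 1

    with open(corpus_path, encoding="utf-8") as f:
        for line in f:
            buf.append(line)
            if len(buf) >= lines_per_shard:
                flush()
    if buf:
        flush()

    meta = [{"source": sources, "source_lang": "en", "weight": 1.0}]
    with open(os.path.join(data_dir, "json", f"{split}.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=4)
```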
### Training Command
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
    --task pretraining \
    --tokens-per-sample 512 \
    --mask-prob 0.15 \
    --span-length 3.0 \
    --leave-unmasked-prob 0.0 \
    --random-token-prob 0.0 \
    --criterion masked_lm \
    --arch mlm_base \
    --share-encoder-input-output-embed \
    --required-batch-size-multiple 8 \
    --spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
    --dict-file ${PATH_TO_DATA}/dict.txt \
    --optimizer adam \
    --adam-betas '(0.9,0.98)' \
    --adam-eps 1e-6 \
    --clip-norm 2.0 \
    --lr-scheduler polynomial_decay \
    --lr 0.0005 \
    --warmup-updates 10000 \
    --total-num-update 125000 \
    --max-update 125000 \
    --max-sentences 32 \
    --update-freq 1 \
    --log-format simple \
    --log-interval 100 \
    --disable-validation \
    --save-interval-updates 5000 \
    --no-epoch-checkpoints \
    --fp16 \
    --fp16-init-scale 4 \
    --fp16-scale-window 256 \
    --min-loss-scale 0.0001 \
    --seed 1 \
    --save-dir ${PATH_TO_CKPT} \
    --ddp-backend=no_c10d \
    --distributed-no-spawn \
    --reset-dataloader \
    --batch-read-ahead 10000 \
    --rel-pos-buckets 32 \
    --max-rel-pos 128 \
    --deepnorm
```

## Example: GPT Pretraining

### Data Format

We use the same format as FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data).
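For reference, the preprocessing step of that linked example looks roughly like the following (a sketch based on the FairSeq documentation, not on this repository; the dataset paths are placeholders):

```bash
# Sketch: binarize a tokenized corpus with fairseq-preprocess, as in the linked example.
fairseq-preprocess \
    --only-source \
    --trainpref wikitext-103/wiki.train.tokens \
    --validpref wikitext-103/wiki.valid.tokens \
    --testpref wikitext-103/wiki.test.tokens \
    --destdir ${PATH_TO_DATA} \
    --workers 20
```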
### Dense Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
    ${PATH_TO_DATA} \
    --num-workers 2 \
    --activation-fn gelu \
    --share-decoder-input-output-embed \
    --validate-interval-updates 1000 \
    --save-interval-updates 1000 \
    --no-epoch-checkpoints \
    --memory-efficient-fp16 \
    --fp16-init-scale 4 \
    --arch lm_base \
    --task language_modeling \
    --sample-break-mode none \
    --tokens-per-sample 128 \
    --optimizer adam --adam-betas "(0.9, 0.98)" \
    --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr 5e-4 \
    --lr-scheduler polynomial_decay \
    --warmup-updates 750 \
    --dropout 0.1 \
    --attention-dropout 0.1 \
    --weight-decay 0.01 \
    --batch-size 4 \
    --update-freq 1 \
    --required-batch-size-multiple 1 \
    --total-num-update 50000 \
    --max-update 50000 \
    --seed 1 \
    --ddp-backend=c10d
```

### Sparse (MoE) Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
    ${PATH_TO_DATA} \
    --num-workers 2 \
    --activation-fn gelu \
    --share-decoder-input-output-embed \
    --validate-interval-updates 1000 \
    --save-interval-updates 1000 \
    --no-epoch-checkpoints \
    --memory-efficient-fp16 \
    --fp16-init-scale 4 \
    --arch lm_base \
    --task language_modeling \
    --sample-break-mode none \
    --tokens-per-sample 128 \
    --optimizer adam --adam-betas "(0.9, 0.98)" \
    --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr 5e-4 \
    --lr-scheduler polynomial_decay \
    --warmup-updates 750 \
    --dropout 0.1 \
    --attention-dropout 0.1 \
    --weight-decay 0.01 \
    --batch-size 4 \
    --update-freq 1 \
    --required-batch-size-multiple 1 \
    --total-num-update 50000 \
    --max-update 50000 \
    --seed 1 \
    --ddp-backend=no_c10d \
    --moe-expert-count 2 --moe-freq 2 \
    --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
    --moe-eval-capacity-token-fraction -1.0 \
    --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
    --use-xmoe
```

## Example: Machine Translation

### Data Format

We follow FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data.
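Concretely, the linked example prepares and binarizes a parallel corpus along these lines (a sketch taken from the FairSeq documentation, not from this repository; the IWSLT'14 De-En paths are placeholders):

```bash
# Sketch: tokenize/apply BPE with the helper script from the linked example,
# then binarize with fairseq-preprocess.
bash prepare-iwslt14.sh
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref iwslt14.tokenized.de-en/train \
    --validpref iwslt14.tokenized.de-en/valid \
    --testpref iwslt14.tokenized.de-en/test \
    --destdir ${PATH_TO_DATA} \
    --workers 20
```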
### Dense Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
    ${PATH_TO_DATA} \
    --arch mt_base --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --max-tokens 4096 --fp16
```

### Sparse (MoE) Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
    ${PATH_TO_DATA} \
    --arch mt_base --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --moe-expert-count 2 --moe-freq 2 \
    --moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
    --moe-eval-capacity-token-fraction -1.0 \
    --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
    --use-xmoe \
    --max-tokens 4096 --fp16
```
0  examples/fairseq/__init__.py  Normal file

7  examples/fairseq/generate.py  Normal file
@@ -0,0 +1,7 @@
import models
import tasks

from fairseq_cli.generate import cli_main

if __name__ == "__main__":
    cli_main()

7  examples/fairseq/interactive.py  Normal file
@@ -0,0 +1,7 @@
import models
import tasks

from fairseq_cli.interactive import cli_main

if __name__ == "__main__":
    cli_main()

33  examples/fairseq/models/__init__.py  Normal file
@@ -0,0 +1,33 @@
import argparse
import importlib
import os

MODEL_REGISTRY = {}
MODEL_DATACLASS_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_NAME_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}

# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("models." + model_name)

        # extra `model_parser` for sphinx
        if model_name in MODEL_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_archs = parser.add_argument_group("Named architectures")
            group_archs.add_argument(
                "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
            )
            group_args = parser.add_argument_group("Additional command-line arguments")
            MODEL_REGISTRY[model_name].add_args(group_args)
            globals()[model_name + "_parser"] = parser
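A hedged note on what the auto-import gives you: once `import models` has run (as `generate.py` and `interactive.py` above do, assuming `examples/fairseq/` is the working directory so that `models` is importable), the architectures registered by the model files in this directory become visible through fairseq's standard registries, for example:

```python
import models  # triggers the importlib loop above, which runs the fairseq decorators

from fairseq.models import ARCH_MODEL_REGISTRY

# architectures registered by bert.py and language_modeling.py in this directory
assert "mlm_base" in ARCH_MODEL_REGISTRY
assert "lm_base" in ARCH_MODEL_REGISTRY
```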
459  examples/fairseq/models/bert.py  Normal file
@@ -0,0 +1,459 @@
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from fairseq import utils
|
||||||
|
from fairseq.distributed import fsdp_wrap
|
||||||
|
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
|
from fairseq.models.transformer import (
|
||||||
|
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
||||||
|
)
|
||||||
|
from fairseq.modules import PositionalEmbedding
|
||||||
|
from fairseq.models.squad import SQuADHead
|
||||||
|
from torch import Tensor
|
||||||
|
from omegaconf import II
|
||||||
|
from .machine_translation import MTEncoder as Encoder
|
||||||
|
from torchscale.architecture.config import EncoderConfig
|
||||||
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
|
||||||
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BertConfig(FairseqDataclass):
|
||||||
|
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
default="relu", metadata={"help": "activation function to use"}
|
||||||
|
)
|
||||||
|
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
||||||
|
attention_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability for attention weights"}
|
||||||
|
)
|
||||||
|
activation_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||||
|
)
|
||||||
|
encoder_embed_dim: int = field(
|
||||||
|
default=512, metadata={"help": "encoder embedding dimension"}
|
||||||
|
)
|
||||||
|
encoder_output_dim: int = field(
|
||||||
|
default=512, metadata={"help": "encoder output dimension"}
|
||||||
|
)
|
||||||
|
encoder_input_dim: int = field(
|
||||||
|
default=512, metadata={"help": "encoder input dimension"}
|
||||||
|
)
|
||||||
|
encoder_ffn_embed_dim: int = field(
|
||||||
|
default=2048, metadata={"help": "encoder embedding dimension for FFN"}
|
||||||
|
)
|
||||||
|
encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"})
|
||||||
|
encoder_attention_heads: int = field(
|
||||||
|
default=8, metadata={"help": "num encoder attention heads"}
|
||||||
|
)
|
||||||
|
encoder_normalize_before: bool = field(
|
||||||
|
default=False, metadata={"help": "apply layernorm before each encoder block"}
|
||||||
|
)
|
||||||
|
no_encoder_final_norm: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "don't add an extra layernorm after the last encoder block"},
|
||||||
|
)
|
||||||
|
no_token_positional_embeddings: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "if set, disables positional embeddings (outside self attention)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
share_encoder_input_output_embed: bool = field(
|
||||||
|
default=False, metadata={"help": "share encoder input and output embeddings"}
|
||||||
|
)
|
||||||
|
encoder_learned_pos: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "use learned positional embeddings in the encoder"},
|
||||||
|
)
|
||||||
|
layernorm_embedding: bool = field(
|
||||||
|
default=False, metadata={"help": "add layernorm to embedding"}
|
||||||
|
)
|
||||||
|
no_scale_embedding: bool = field(
|
||||||
|
default=False, metadata={"help": "if True, dont scale embeddings"}
|
||||||
|
)
|
||||||
|
checkpoint_activations: bool = field(
|
||||||
|
default=False, metadata={"help": "checkpoint activations at each layer"}
|
||||||
|
)
|
||||||
|
offload_activations: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "move checkpointed activations to CPU after they are used."},
|
||||||
|
)
|
||||||
|
# config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
|
||||||
|
encoder_layerdrop: float = field(
|
||||||
|
default=0.0, metadata={"help": "LayerDrop probability for encoder"}
|
||||||
|
)
|
||||||
|
encoder_layers_to_keep: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "which layers to *keep* when pruning as a comma-separated list"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# config for Fully Sharded Data Parallel (FSDP) training
|
||||||
|
min_params_to_wrap: int = field(
|
||||||
|
default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"minimum number of params for a layer to be wrapped with FSDP() when "
|
||||||
|
"training with --ddp-backend=fully_sharded. Smaller values will "
|
||||||
|
"improve memory efficiency, but may make torch.distributed "
|
||||||
|
"communication less efficient due to smaller input sizes. This option "
|
||||||
|
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||||
|
"--offload-activations are passed."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
max_source_positions: int = field(
|
||||||
|
default=1024, metadata={"help": "max source positions"}
|
||||||
|
)
|
||||||
|
pooler_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
default="relu", metadata={"help": "activation function to use for pooler layer"}
|
||||||
|
)
|
||||||
|
pooler_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
|
||||||
|
)
|
||||||
|
# options from other parts of the config
|
||||||
|
# add_bos_token: bool = II("task.add_bos_token")
|
||||||
|
# tokens_per_sample: int = II("task.tokens_per_sample")
|
||||||
|
tpu: bool = II("common.tpu")
|
||||||
|
rel_pos_buckets: int = field(
|
||||||
|
default=0, metadata={"help": ""}
|
||||||
|
)
|
||||||
|
max_rel_pos: int = field(
|
||||||
|
default=0, metadata={"help": ""}
|
||||||
|
)
|
||||||
|
moe_freq: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={
|
||||||
|
"help": "Frequency at which we insert MoE Transformer layers"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_expert_count: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={
|
||||||
|
"help": "Number of experts in each MoE Layer"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_gating_use_fp32: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use FP32 computations in MoE top2 gating function"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_second_expert_policy: str = field(
|
||||||
|
default='sampling',
|
||||||
|
metadata={
|
||||||
|
"help": "policy for second expert, options: all/sampling/random"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_expert_ffn_dim: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "MoE expert FFN dimension"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_top1_expert: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use top1 gate instead of top2"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
|
default=0.25,
|
||||||
|
metadata={
|
||||||
|
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
|
default='world_size',
|
||||||
|
metadata={
|
||||||
|
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
record_a2a_perf_stats: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||||
|
)
|
||||||
|
dummy_a2a: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||||
|
)
|
||||||
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||||
|
)
|
||||||
|
ddp_rank: int = II("distributed_training.distributed_rank")
|
||||||
|
deepnorm: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
subln: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("mlm", dataclass=BertConfig)
|
||||||
|
class BertModel(BaseFairseqModel):
|
||||||
|
|
||||||
|
def __init__(self, args, encoder):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
self.encoder = encoder
|
||||||
|
self.padding_idx = self.encoder.embed_tokens.padding_idx
|
||||||
|
self.classification_heads = nn.ModuleDict()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, args, task):
|
||||||
|
"""Build a new model instance."""
|
||||||
|
|
||||||
|
args.max_source_positions = getattr(
|
||||||
|
args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_tokens = cls.build_embedding(
|
||||||
|
args, task.dictionary, args.encoder_embed_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_positions = (
|
||||||
|
PositionalEmbedding(
|
||||||
|
args.max_source_positions,
|
||||||
|
args.encoder_embed_dim,
|
||||||
|
task.dictionary.pad(),
|
||||||
|
learned=args.encoder_learned_pos,
|
||||||
|
)
|
||||||
|
if not args.no_token_positional_embeddings
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
lm_head = cls.build_lm_head(
|
||||||
|
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
|
||||||
|
)
|
||||||
|
|
||||||
|
config = EncoderConfig()
|
||||||
|
config.override(args)
|
||||||
|
|
||||||
|
encoder = Encoder(
|
||||||
|
config,
|
||||||
|
embed_tokens=embed_tokens,
|
||||||
|
embed_positions=embed_positions,
|
||||||
|
output_projection=lm_head,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
dictionary=task.dictionary,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(args, encoder)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_embedding(cls, args, dictionary, embed_dim, path=None):
|
||||||
|
embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
|
||||||
|
return embed_tokens
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
|
||||||
|
return LMHead(embed_dim, output_dim, activation_fn, weight)
|
||||||
|
|
||||||
|
def output_layer(self, features, masked_tokens=None):
|
||||||
|
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
|
||||||
|
|
||||||
|
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
|
||||||
|
"""Register a classification head."""
|
||||||
|
if name in self.classification_heads:
|
||||||
|
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
||||||
|
prev_inner_dim = self.classification_heads[name].dense.out_features
|
||||||
|
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
||||||
|
logger.warning(
|
||||||
|
're-registering head "{}" with num_classes {} (prev: {}) '
|
||||||
|
'and inner_dim {} (prev: {})'.format(
|
||||||
|
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.classification_heads[name] = ClassificationHead(
|
||||||
|
self.args.encoder_embed_dim,
|
||||||
|
inner_dim or self.args.encoder_embed_dim,
|
||||||
|
num_classes,
|
||||||
|
self.args.pooler_activation_fn,
|
||||||
|
self.args.pooler_dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_question_answering_head(self, name, num_classes=None):
|
||||||
|
self.classification_heads[name] = SQuADHead(
|
||||||
|
self.args.encoder_embed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upgrade_state_dict_named(self, state_dict, name):
|
||||||
|
prefix = name + '.' if name != '' else ''
|
||||||
|
|
||||||
|
# upgrade children modules
|
||||||
|
super().upgrade_state_dict_named(state_dict, name)
|
||||||
|
|
||||||
|
# Handle new classification heads present in the state dict.
|
||||||
|
current_head_names = (
|
||||||
|
[] if not hasattr(self, 'classification_heads')
|
||||||
|
else self.classification_heads.keys()
|
||||||
|
)
|
||||||
|
keys_to_delete = []
|
||||||
|
for k in state_dict.keys():
|
||||||
|
if not k.startswith(prefix + 'classification_heads.'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
|
||||||
|
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
|
||||||
|
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
|
||||||
|
|
||||||
|
if getattr(self.args, 'load_checkpoint_heads', False):
|
||||||
|
if head_name not in current_head_names:
|
||||||
|
self.register_classification_head(head_name, num_classes, inner_dim)
|
||||||
|
else:
|
||||||
|
if head_name not in current_head_names:
|
||||||
|
logger.warning(
|
||||||
|
'deleting classification head ({}) from checkpoint '
|
||||||
|
'not present in current model: {}'.format(head_name, k)
|
||||||
|
)
|
||||||
|
keys_to_delete.append(k)
|
||||||
|
elif (
|
||||||
|
num_classes != self.classification_heads[head_name].out_proj.out_features
|
||||||
|
or inner_dim != self.classification_heads[head_name].dense.out_features
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
'deleting classification head ({}) from checkpoint '
|
||||||
|
'with different dimensions than current model: {}'.format(head_name, k)
|
||||||
|
)
|
||||||
|
keys_to_delete.append(k)
|
||||||
|
for k in keys_to_delete:
|
||||||
|
del state_dict[k]
|
||||||
|
|
||||||
|
# Copy any newly-added classification heads into the state dict
|
||||||
|
# with their current weights.
|
||||||
|
if hasattr(self, 'classification_heads'):
|
||||||
|
cur_state = self.classification_heads.state_dict()
|
||||||
|
for k, v in cur_state.items():
|
||||||
|
if prefix + 'classification_heads.' + k not in state_dict:
|
||||||
|
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
|
||||||
|
state_dict[prefix + 'classification_heads.' + k] = v
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src_tokens=None,
|
||||||
|
features_only=False,
|
||||||
|
return_all_hiddens=False,
|
||||||
|
classification_head_name=None,
|
||||||
|
masked_tokens=None,
|
||||||
|
**kwargs):
|
||||||
|
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
|
||||||
|
x, extra = encoder_out["encoder_out"], encoder_out
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
if classification_head_name is not None:
|
||||||
|
x = self.classification_heads[classification_head_name](x)
|
||||||
|
elif not features_only:
|
||||||
|
x = self.output_layer(x, masked_tokens=masked_tokens)
|
||||||
|
|
||||||
|
return x, extra
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationHead(nn.Module):
|
||||||
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim,
|
||||||
|
inner_dim,
|
||||||
|
num_classes,
|
||||||
|
activation_fn,
|
||||||
|
pooler_dropout,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(input_dim, inner_dim)
|
||||||
|
self.activation_fn = utils.get_activation_fn(activation_fn)
|
||||||
|
self.dropout = nn.Dropout(p=pooler_dropout)
|
||||||
|
self.out_proj = nn.Linear(inner_dim, num_classes)
|
||||||
|
|
||||||
|
def forward(self, features, **kwargs):
|
||||||
|
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.dense(x)
|
||||||
|
x = self.activation_fn(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.out_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class LMHead(nn.Module):
|
||||||
|
"""Head for masked language modeling."""
|
||||||
|
|
||||||
|
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(embed_dim, embed_dim)
|
||||||
|
self.activation_fn = utils.get_activation_fn(activation_fn)
|
||||||
|
self.layer_norm = LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
if weight is None:
|
||||||
|
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
|
||||||
|
self.weight = weight
|
||||||
|
self.bias = nn.Parameter(torch.zeros(output_dim))
|
||||||
|
|
||||||
|
def forward(self, features, masked_tokens=None, **kwargs):
|
||||||
|
# Only project the masked tokens while training,
|
||||||
|
# saves both memory and computation
|
||||||
|
if masked_tokens is not None:
|
||||||
|
features = features[masked_tokens, :]
|
||||||
|
|
||||||
|
x = self.dense(features)
|
||||||
|
x = self.activation_fn(x)
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
# project back to size of vocabulary with bias
|
||||||
|
x = F.linear(x, self.weight) + self.bias
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@register_model_architecture("mlm", "mlm_base")
|
||||||
|
def base_unilm_architecture(args):
|
||||||
|
if hasattr(args, "encoder_final_norm"):
|
||||||
|
args.no_encoder_final_norm = not args.encoder_final_norm
|
||||||
|
|
||||||
|
args.dropout = getattr(args, "dropout", 0.1)
|
||||||
|
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
||||||
|
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
||||||
|
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
||||||
|
|
||||||
|
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
||||||
|
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
|
||||||
|
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
||||||
|
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
||||||
|
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
|
||||||
|
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||||
|
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
||||||
|
|
||||||
|
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
||||||
|
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
||||||
|
|
||||||
|
# args.add_bos_token = getattr(args, "add_bos_token", False)
|
||||||
|
args.no_token_positional_embeddings = getattr(
|
||||||
|
args, "no_token_positional_embeddings", False
|
||||||
|
)
|
||||||
|
args.share_encoder_input_output_embed = getattr(
|
||||||
|
args, "share_encoder_input_output_embed", True
|
||||||
|
)
|
||||||
|
args.encoder_output_dim = getattr(
|
||||||
|
args, "encoder_output_dim", args.encoder_embed_dim
|
||||||
|
)
|
||||||
|
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
|
||||||
|
|
||||||
|
# Model training is not stable without this
|
||||||
|
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
|
||||||
|
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
|
||||||
|
|
||||||
|
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
||||||
|
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
|
||||||
|
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
||||||
|
args.offload_activations = getattr(args, "offload_activations", False)
|
||||||
|
if args.offload_activations:
|
||||||
|
args.checkpoint_activations = True
|
357  examples/fairseq/models/language_modeling.py  Normal file
@@ -0,0 +1,357 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from fairseq import options, utils
|
||||||
|
from fairseq import distributed_utils
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
|
from fairseq.models import (
|
||||||
|
FairseqIncrementalDecoder,
|
||||||
|
FairseqLanguageModel,
|
||||||
|
register_model,
|
||||||
|
register_model_architecture,
|
||||||
|
)
|
||||||
|
from fairseq.models.transformer import (
|
||||||
|
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
|
||||||
|
)
|
||||||
|
from fairseq.modules import PositionalEmbedding
|
||||||
|
from torchscale.architecture.decoder import Decoder
|
||||||
|
from torchscale.architecture.config import DecoderConfig
|
||||||
|
from omegaconf import II
|
||||||
|
|
||||||
|
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LanguageConfig(FairseqDataclass):
|
||||||
|
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
|
||||||
|
default="relu", metadata={"help": "activation function to use"}
|
||||||
|
)
|
||||||
|
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
|
||||||
|
attention_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability for attention weights"}
|
||||||
|
)
|
||||||
|
activation_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||||
|
)
|
||||||
|
relu_dropout: float = field(
|
||||||
|
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
|
||||||
|
)
|
||||||
|
decoder_embed_dim: int = field(
|
||||||
|
default=512, metadata={"help": "decoder embedding dimension"}
|
||||||
|
)
|
||||||
|
decoder_output_dim: int = field(
|
||||||
|
default=512, metadata={"help": "decoder output dimension"}
|
||||||
|
)
|
||||||
|
decoder_input_dim: int = field(
|
||||||
|
default=512, metadata={"help": "decoder input dimension"}
|
||||||
|
)
|
||||||
|
decoder_ffn_embed_dim: int = field(
|
||||||
|
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
|
||||||
|
)
|
||||||
|
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
|
||||||
|
decoder_attention_heads: int = field(
|
||||||
|
default=8, metadata={"help": "num decoder attention heads"}
|
||||||
|
)
|
||||||
|
decoder_normalize_before: bool = field(
|
||||||
|
default=False, metadata={"help": "apply layernorm before each decoder block"}
|
||||||
|
)
|
||||||
|
no_token_positional_embeddings: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "if set, disables positional embeddings (outside self attention)"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
share_decoder_input_output_embed: bool = field(
|
||||||
|
default=False, metadata={"help": "share decoder input and output embeddings"}
|
||||||
|
)
|
||||||
|
decoder_learned_pos: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "use learned positional embeddings in the decoder"},
|
||||||
|
)
|
||||||
|
layernorm_embedding: bool = field(
|
||||||
|
default=False, metadata={"help": "add layernorm to embedding"}
|
||||||
|
)
|
||||||
|
no_scale_embedding: bool = field(
|
||||||
|
default=False, metadata={"help": "if True, dont scale embeddings"}
|
||||||
|
)
|
||||||
|
checkpoint_activations: bool = field(
|
||||||
|
default=False, metadata={"help": "checkpoint activations at each layer"}
|
||||||
|
)
|
||||||
|
offload_activations: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "move checkpointed activations to CPU after they are used."},
|
||||||
|
)
|
||||||
|
# config for Fully Sharded Data Parallel (FSDP) training
|
||||||
|
min_params_to_wrap: int = field(
|
||||||
|
default=DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"minimum number of params for a layer to be wrapped with FSDP() when "
|
||||||
|
"training with --ddp-backend=fully_sharded. Smaller values will "
|
||||||
|
"improve memory efficiency, but may make torch.distributed "
|
||||||
|
"communication less efficient due to smaller input sizes. This option "
|
||||||
|
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||||
|
"--offload-activations are passed."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_freq: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={
|
||||||
|
"help": "Frequency at which we insert MoE Transformer layers"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
moe_expert_count: int = field(
|
||||||
|
default=0,
|
||||||
|
metadata={
|
||||||
|
"help": "Number of experts in each MoE Layer"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_gating_use_fp32: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use FP32 computations in MoE top2 gating function"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_second_expert_policy: str = field(
|
||||||
|
default='sampling',
|
||||||
|
metadata={
|
||||||
|
"help": "policy for second expert, options: all/sampling/random"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_expert_ffn_dim: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "MoE expert FFN dimension"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_top1_expert: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use top1 gate instead of top2"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
|
default=0.25,
|
||||||
|
metadata={
|
||||||
|
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
|
default='world_size',
|
||||||
|
metadata={
|
||||||
|
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
record_a2a_perf_stats: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
||||||
|
)
|
||||||
|
dummy_a2a: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
||||||
|
)
|
||||||
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
||||||
|
)
|
||||||
|
use_xmoe: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# options from other parts of the config
|
||||||
|
add_bos_token: bool = II("task.add_bos_token")
|
||||||
|
tokens_per_sample: int = II("task.tokens_per_sample")
|
||||||
|
max_target_positions: Optional[int] = II("task.max_target_positions")
|
||||||
|
tpu: bool = II("common.tpu")
|
||||||
|
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
|
||||||
|
fp16: bool = II("common.fp16")
|
||||||
|
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
|
||||||
|
ddp_backend: str = II("distributed_training.ddp_backend")
|
||||||
|
world_size: int = II("distributed_training.distributed_world_size")
|
||||||
|
distributed_rank: int = II("distributed_training.distributed_rank")
|
||||||
|
ddp_rank: int = II("distributed_training.distributed_rank")
|
||||||
|
deepnorm: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
subln: Optional[bool] = field(
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
rel_pos_buckets: Optional[int] = field(
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
max_rel_pos: Optional[int] = field(
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("lm", dataclass=LanguageConfig)
|
||||||
|
class LanguageModel(FairseqLanguageModel):
|
||||||
|
|
||||||
|
def __init__(self, args, decoder):
|
||||||
|
self.args = args
|
||||||
|
super().__init__(decoder)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, args, task):
|
||||||
|
|
||||||
|
if getattr(args, "max_target_positions", None) is None:
|
||||||
|
args.max_target_positions = getattr(
|
||||||
|
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_tokens = cls.build_embedding(
|
||||||
|
args, task.source_dictionary, args.decoder_embed_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_positions = (
|
||||||
|
PositionalEmbedding(
|
||||||
|
args.max_target_positions,
|
||||||
|
args.decoder_embed_dim,
|
||||||
|
task.dictionary.pad(),
|
||||||
|
learned=args.decoder_learned_pos,
|
||||||
|
)
|
||||||
|
if not args.no_token_positional_embeddings
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.share_decoder_input_output_embed:
|
||||||
|
output_projection = torch.nn.Linear(
|
||||||
|
embed_tokens.weight.shape[1],
|
||||||
|
embed_tokens.weight.shape[0],
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
output_projection.weight = embed_tokens.weight
|
||||||
|
else:
|
||||||
|
output_projection = torch.nn.Linear(
|
||||||
|
decoder_embed_dim, len(task.dictionary), bias=False
|
||||||
|
)
|
||||||
|
torch.nn.init.normal_(
|
||||||
|
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
|
||||||
|
)

        if (
            getattr(args, 'moe_freq', 0) > 0
            and (
                getattr(args, 'fp16', False)
                and not getattr(args, 'memory_efficient_fp16', False)
                and getattr(args, 'ddp_backend', None) != "fully_sharded"
            )
        ):
            assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        config = DecoderConfig()
        config.override(args)

        decoder = LMDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=False,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        return Embedding(len(dictionary), embed_dim, dictionary.pad())


class LMDecoder(Decoder, FairseqIncrementalDecoder):

    def forward(self, src_tokens, **kwargs):
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

    def max_positions(self):
        return self.embed_positions.max_positions

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result


@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.activation_fn = getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)

    args.base_layers = getattr(args, "base_layers", 0)
    args.base_sublayers = getattr(args, "base_sublayers", 1)
    args.base_shuffle = getattr(args, "base_shuffle", False)

    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
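
As a quick illustration of the glue above (a sketch, not part of this commit): config.override(args) copies matching attribute names from the fairseq args namespace onto the torchscale config, which is how the lm_base defaults reach the decoder. Assuming that behaviour, a standalone decoder config can be filled the same way:

# Illustrative only -- not part of language_modeling.py.
from argparse import Namespace
from torchscale.architecture.config import DecoderConfig

cfg = DecoderConfig()
cfg.override(Namespace(decoder_embed_dim=768, decoder_layers=12, dropout=0.1))
# cfg now carries decoder_embed_dim=768 and decoder_layers=12; fields not present
# on the namespace keep their DecoderConfig defaults.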

450  examples/fairseq/models/machine_translation.py  Normal file
@@ -0,0 +1,450 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap
from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import (
    AdaptiveSoftmax,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder

from torch import Tensor
import logging
logger = logging.getLogger(__name__)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):

    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion'),
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE Layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
                            help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE Expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top1 gate instead of top2")
        parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
                            help="Fraction of tokens as capacity during validation" + \
                                 "if set to negative, use same as training. range: (0.0, 1.0].")
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")

        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-Moe")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE Params")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        config = EncoderConfig()
        config.override(args)

        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
        config = DecoderConfig()
        config.override(args)

        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        encoder_out = self.encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)


class MTEncoder(Encoder, FairseqEncoder):

    def forward(self, src_tokens, **kwargs):
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)

    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
        new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
        new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
        }

    def max_positions(self):
        return self.embed_positions.max_positions


@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
    args.is_moe = getattr(args, "is_moe", False)
    args.selected_expert_count = getattr(args, "selected_expert_count", 2)

32  examples/fairseq/tasks/__init__.py  Normal file
@@ -0,0 +1,32 @@
import argparse
import importlib
import os

# register dataclass
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()

# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
    path = os.path.join(tasks_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        task_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("tasks." + task_name)

        # expose `task_parser` for sphinx
        if task_name in TASK_REGISTRY:
            parser = argparse.ArgumentParser(add_help=False)
            group_task = parser.add_argument_group("Task name")
            # fmt: off
            group_task.add_argument('--task', metavar=task_name,
                                    help='Enable this task with: ``--task=' + task_name + '``')
            # fmt: on
            group_args = parser.add_argument_group("Additional command-line arguments")
            TASK_REGISTRY[task_name].add_args(group_args)
            globals()[task_name + "_parser"] = parser

0  examples/fairseq/tasks/data/__init__.py  Normal file
78  examples/fairseq/tasks/data/basic_loader.py  Normal file
@@ -0,0 +1,78 @@
import math
import re
import sys
import time
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils


class BaseBatchGen(CheckpointableIterator):
    """
    This is a base class for batch generators that use infinibatch
    """

    def __init__(self):
        self._iter = None
        self.epoch = 1
        self.next_epoch_idx = 1
        self.sharded_checkpoint = True
        self.should_close_after_finished = True

    def _build_iter(self):
        """
        Build infinibatch iterator and assign to self._iter
        """
        raise NotImplementedError()

    def _move_to_tensor(self, batch):

        def to_tensor(x):
            return torch.tensor(x)

        return utils.apply_to_sample(to_tensor, batch)

    @property
    def iterator(self):
        if self._iter is None:
            raise NotImplementedError("_build_iter() must be called first")
        return self._iter

    def __iter__(self):
        if self._iter is None:
            raise NotImplementedError("_build_iter() must be called first")
        return self._iter

    def __next__(self):
        return next(self._iter)

    def setstate(self, value):
        self._iter.setstate(value)

    def getstate(self):
        return self._iter.getstate()

    def close(self):
        self._iter.close()

    def __len__(self) -> int:
        return 819200000

    def next_epoch_itr(
        self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
    ):
        return self

    def end_of_epoch(self) -> bool:
        return False

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        return self.getstate()

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.setstate(state_dict)

    @property
    def first_batch(self):
        return "DUMMY"

308  examples/fairseq/tasks/data/mlm_loader.py  Normal file
@@ -0,0 +1,308 @@
import glob
import os
import torch
import numpy as np
import time
import json
import random
import itertools
import copy

from infinibatch import iterators
from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator


class MLMLoader(BaseBatchGen):

    def __init__(
        self,
        args,
        dataset,
        dictionary,
        tokenizer,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
    ):
        super().__init__()
        self.args = args
        self.data = dataset.data
        self.data_dir = dataset.data_dir
        self.shuffle = dataset.shuffle
        self.dictionary = dictionary
        self.tokenizer = tokenizer

        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.max_positions = max_positions
        self.tokens_per_sample = args.tokens_per_sample
        self.sample_break_mode = args.sample_break_mode
        self.ignore_invalid_inputs = ignore_invalid_inputs
        self.required_batch_size_multiple = required_batch_size_multiple
        self.seed = str(seed)
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.batch_read_ahead = args.batch_read_ahead

        self._build_iter()

    def _build_iter(self):
        tokenized_lines = self._multilingual_tokenize()
        self.padded_batches = self._batchify(tokenized_lines)

        prefetch_batches = iterators.PrefetchIterator(
            self.padded_batches,
            buffer_size=10000,
            buffer_in_main_process=True,
            log_empty_buffer_warning=True and self.shard_id == 0,
        )

        prefetch_batches = iterators.MapIterator(
            prefetch_batches, self._move_to_tensor
        )

        self._iter = prefetch_batches

    def _multilingual_tokenize(self):
        multilingual_iters = []
        weights = []

        for data in self.data:
            multilingual_iters.append(
                self._tokenize(data)
            )
            if 'weight' in data:
                weights.append(float(data['weight']))
            else:
                weights.append(int(data['count']))

        if len(multilingual_iters) == 1:
            return multilingual_iters[0]

        # WeightIterator takes (weights, seed); pass the loader seed explicitly
        sampling_iterator = WeightIterator(weights, self.seed)
        control_iterator = NativeCheckpointableIterator(sampling_iterator)
        tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)

        return tokenized_lines

    def _tokenize(self, data):
        '''
        data:
        {
            'source': list[Path],
            'source_lang': str,
            'count': int,
            'weight': float,
            'name': str,
        }
        '''
        dataset = list(
            zip(
                data['source'],
                itertools.repeat(data['source_lang']),
            )
        )

        if self.shuffle:
            chunk_files = \
                iterators.InfinitePermutationSourceIterator(
                    dataset,
                    seed=self.seed,
                    shuffle=self.shuffle,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )
        else:
            chunk_files = \
                iterators.ChunkedSourceIterator(
                    dataset,
                    num_instances=self.num_shards,
                    instance_rank=self.shard_id,
                )

        tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
        tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)

        return tokenized_lines

    def _batchify(self, lines):

        if self.max_sentences is not None:
            if self.batch_read_ahead > 0:
                lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
            batches = iterators.FixedBatchIterator(lines, self.max_sentences)
        else:
            def dynamic_batch_size(sample):
                lengths = [len(x) for x in sample]
                batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
                return max(1, batch_size)

            batches = iterators.BucketedReadaheadBatchIterator(
                lines,
                read_ahead=self.batch_read_ahead,
                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                batch_size=dynamic_batch_size,
                shuffle=self.shuffle,
                seed=self.seed,
            )

        def collate(batch):
            batch_size = len(batch)

            mlm_source_max_length = max([len(x[0]) for x in batch])
            mlm_target_max_length = max([len(x[1]) for x in batch])
            s2s_source_max_length = max([len(x[2]) for x in batch])
            s2s_target_max_length = max([len(x[3]) for x in batch])

            mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                     fill_value=self.dictionary.pad())
            s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                         fill_value=self.dictionary.pad())

            for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
                mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
                mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
                s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
                s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
                s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]

            ret_batch = {
                'net_input': {
                    'src_tokens': mlm_source_ids.astype(np.int64),
                },
                'target': mlm_target_ids.astype(np.int64),
                'nsentences': batch_size,
                'ntokens': sum([len(x[0]) for x in batch]),
            }

            return ret_batch

        padded_batches = iterators.MapIterator(
            batches, collate
        )

        return padded_batches
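
    # Illustrative numbers for dynamic_batch_size above (not part of mlm_loader.py):
    # with max_tokens=4096, required_batch_size_multiple=8 and a bucket whose longest
    # sample holds 100 tokens, 4096 // 100 = 40 and 40 // 8 * 8 = 40, so the bucket is
    # emitted as batches of 40 sentences (about 4000 tokens each).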

    def _prepare(self, _random, doc):
        nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
        nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
        return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans

    def _mask_lm(self, _random, doc):
        def mask_tokens():
            return f"<mask>"

        length = len(doc)
        mask_tokens_num = int(length * self.args.mask_prob)
        mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
        possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
        possible_mask_positions = sorted(possible_mask_positions)

        nonmasked_tokens = copy.deepcopy(doc)
        masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]

        for position in possible_mask_positions:
            # masked_tokens.append(nonmasked_tokens[position])
            masked_tokens[position] = nonmasked_tokens[position]
            nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]

        return nonmasked_tokens, masked_tokens
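
    # Illustrative trace of _mask_lm (not part of mlm_loader.py): with mask_prob=0.15
    # and doc = [bos, 11, 12, 13, 14, 15, 16, 17, 18, eos] (length 10),
    # mask_tokens_num = int(10 * 0.15) = 1, so a single position is sampled, say 4:
    #   nonmasked_tokens = [bos, 11, 12, 13, <mask>, 15, 16, 17, 18, eos]      # encoder input
    #   masked_tokens    = [pad, pad, pad, pad, 14, pad, pad, pad, pad, pad]   # labels (pad = ignore)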

    def _span_corruption(self, _random, doc):

        def mask_tokens(i):
            return f"<mask_{i}>"

        length = len(doc)
        noise_tokens_num = int(length * self.args.mask_prob)
        noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
        noise_spans_num = int(noise_tokens_num / self.args.span_length)
        noise_spans_num = max(noise_spans_num, 1)
        nonnoise_tokens_num = length - noise_tokens_num

        if noise_spans_num == 1:
            noise_split_positions = [0, noise_tokens_num]
        else:
            possible_split_positions = list(range(1, noise_tokens_num))
            _random.shuffle(possible_split_positions)
            noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
            noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

        possible_insert_positions = list(range(nonnoise_tokens_num))
        _random.shuffle(possible_insert_positions)
        noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

        nonnoise_spans, noise_spans = [], []
        last_end = 0
        for i in range(noise_spans_num):
            start_pos = noise_insert_positions[i] + noise_split_positions[i]
            end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
            mask_id = self.dictionary.indices[mask_tokens(i)]

            if getattr(self.args, "remove_target_sentinel", False):
                noise_spans.append(doc[start_pos:end_pos])
            else:
                noise_spans.append([mask_id] + doc[start_pos:end_pos])

            if getattr(self.args, "remove_source_sentinel", False):
                nonnoise_spans.extend(doc[last_end:start_pos])
            else:
                nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])

            last_end = end_pos

        nonnoise_spans.extend(doc[last_end:])
        noise_spans = sum(noise_spans, [])

        return nonnoise_spans, noise_spans
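
    # Illustrative trace of _span_corruption (not part of mlm_loader.py): with a doc of
    # length 20, mask_prob=0.15 and span_length=3, noise_tokens_num = int(20 * 0.15) = 3
    # and noise_spans_num = max(int(3 / 3), 1) = 1, so one contiguous span of 3 tokens is
    # cut out at a random offset. For doc = [t0 .. t19] and an offset that lands on t7..t9:
    #   nonnoise_spans = [t0 .. t6, <mask_0>, t10 .. t19]   # source side with sentinel
    #   noise_spans    = [<mask_0>, t7, t8, t9]             # target side, sentinel-prefixed
    # i.e. a T5-style span-corruption pair.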

    def _read_from_files(self, source_file, source_lang):
        # data = []
        file_path = os.path.join(self.data_dir, source_file)

        if not os.path.exists(file_path):
            print('| file {} does not exist'.format(file_path), flush=True)
            return iter([])  # skip bad file

        with open(file_path, 'r', encoding='utf8') as f:
            lines = f.read().strip().split('\n')

        doc = [self.dictionary.bos()]
        for line in lines:
            if line == "":
                if self.sample_break_mode == 'complete_doc':
                    # data.append(doc)
                    yield doc
                    doc = [self.dictionary.bos()]
                continue

            tokenized_line = self.tokenizer.EncodeAsPieces(line)
            tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]

            if len(tokenized_id) > self.tokens_per_sample:
                continue
            if len(doc) + len(tokenized_id) > self.tokens_per_sample:
                # data.append(doc)
                yield doc
                doc = [self.dictionary.bos()]
            doc.extend(tokenized_id)

        if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
            # data.append(doc)
            yield doc

        # return data

82  examples/fairseq/tasks/data/utils.py  Normal file
@@ -0,0 +1,82 @@
import os
import gzip
import numpy as np
from random import Random
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
import collections
from infinibatch import iterators


def apply_to_sample(f, sample):
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(x):
        if isinstance(x, np.ndarray):
            return f(x)
        elif isinstance(x, collections.OrderedDict):
            # OrderedDict has attributes that need to be preserved
            od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
            od.__dict__ = x.__dict__
            return od
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(x) for x in x]
        elif isinstance(x, tuple):
            return tuple(_apply(x) for x in x)
        elif isinstance(x, set):
            return {_apply(x) for x in x}
        else:
            return x

    return _apply(sample)
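
# Illustrative usage (not part of utils.py): apply_to_sample maps a function over
# every numpy array in a nested container and leaves other leaves untouched, e.g.
#
#   import torch
#   sample = {"net_input": {"src_tokens": np.zeros((2, 4), dtype=np.int64)}, "nsentences": 2}
#   tensors = apply_to_sample(torch.from_numpy, sample)
#   # tensors["net_input"]["src_tokens"] is now a torch.Tensor; "nsentences" is still the int 2.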

class NativeCheckpointableIterator(iterators.CheckpointableIterator):
    def __init__(self, iterable: Iterable):
        self._input_iterable = iterable
        self.setstate(None)

    def getstate(self) -> Dict:
        return {'num_items_yielded': self._num_items_yielded}

    def setstate(self, checkpoint: Optional[Dict]):
        self._iterator = iter(self._input_iterable)
        self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0

    def __next__(self):
        item = next(self._iterator)
        self._num_items_yielded += 1
        return item

    def close(self):
        pass


class WeightIterator(object):
    def __init__(self, weights, seed):
        self.weights = weights
        self.seed = seed
        self.control_index = list(range(len(weights)))
        self.setstate(None)

    def __iter__(self):
        return self

    def getstate(self):
        return {"random_state": self._random_state}

    def setstate(self, checkpoint):
        self._random_state = checkpoint["random_state"] if checkpoint else None
        self._random = None  # this will trigger the lazy initialization in self.__next__

    def __next__(self):
        if self._random is None:
            self._random = Random(self.seed)
            if self._random_state is not None:
                self._random.setstate(self._random_state)
        idx = self._random.choices(self.control_index, self.weights)[0]
        self._random_state = self._random.getstate()
        return idx

    def close(self):
        pass

207  examples/fairseq/tasks/pretraining.py  Normal file
@@ -0,0 +1,207 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
import os
from argparse import Namespace
import json
from omegaconf import MISSING, II, OmegaConf
from typing import Any

import numpy as np
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
import sentencepiece as spm

logger = logging.getLogger(__name__)

SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])


@dataclass
class PretrainingConfig(FairseqDataclass):
    data: str = field(
        default=MISSING,
        metadata={
            "help": "colon separated path to data directories list, "
            "will be iterated upon during epochs in round-robin manner"
        },
    )
    sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
        default="complete",
        metadata={
            "help": 'If omitted or "none", fills each sample with tokens-per-sample '
            'tokens. If set to "complete", splits samples only at the end '
            "of sentence, but may include multiple sentences per sample. "
            '"complete_doc" is similar but respects doc boundaries. '
            'If set to "eos", includes only one sentence per sample.'
        },
    )
    tokens_per_sample: int = field(
        default=1024,
        metadata={"help": "max number of tokens per sample for LM dataset"},
    )
    mask_prob: float = field(
        default=0.15,
        metadata={"help": "probability of replacing a token with mask"},
    )
    leave_unmasked_prob: float = field(
        default=0.1,
        metadata={"help": "probability that a masked token is unmasked"},
    )
    random_token_prob: float = field(
        default=0.1,
        metadata={"help": "probability of replacing a token with a random token"},
    )
    freq_weighted_replacement: bool = field(
        default=False,
        metadata={"help": "sample random replacement words based on word frequencies"},
    )
    mask_whole_words: bool = field(
        default=False,
        metadata={"help": "mask whole words; you may also want to set --bpe"},
    )
    mask_multiple_length: int = field(
        default=1,
        metadata={"help": "repeat the mask indices multiple times"},
    )
    mask_stdev: float = field(
        default=0.0,
        metadata={"help": "stdev of the mask length"},
    )
    shorten_method: SHORTEN_METHOD_CHOICES = field(
        default="none",
        metadata={
            "help": "if not none, shorten sequences that exceed --tokens-per-sample"
        },
    )
    shorten_data_split_list: str = field(
        default="",
        metadata={
            "help": "comma-separated list of dataset splits to apply shortening to, "
            'e.g., "train,valid" (default: all dataset splits)'
        },
    )
    seed: int = II("common.seed")
    span_length: float = field(
        default=3.0,
        metadata={"help": "average span length for masking"},
    )
    remove_source_sentinel: bool = field(
        default=False,
        metadata={"help": "remove the source sentinel for the span corruption task"},
    )
    remove_target_sentinel: bool = field(
        default=False,
        metadata={"help": "remove the target sentinel for the span corruption task"},
    )
    batch_read_ahead: int = field(
        default=100000,
        metadata={"help": "batch read ahead size for infinibatch"},
    )
    required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
    spm_model: str = field(
        default="",
        metadata={
            "help": "sentencepiece model to tokenize the data"
        },
    )
    dict_file: str = field(
        default="",
        metadata={
            "help": ""
        },
    )


@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):

    def __init__(self, cfg, dictionary, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.dictionary = dictionary
        self.tokenizer = tokenizer
        self.seed = cfg.seed
        self.mask_idx = dictionary.index("<mask>")

    @classmethod
    def setup_task(cls, cfg, **kwargs):
        paths = utils.split_paths(cfg.data)
        assert len(paths) > 0
        if cfg.dict_file != "":
            dictionary = Dictionary.load(cfg.dict_file)
        else:
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))

        # add mask token
        dictionary.add_symbol("<mask>")
        for i in range(100):
            dictionary.add_symbol(f"<mask_{i}>")

        dictionary.pad_to_multiple_(cfg.required_batch_size_multiple)
        logger.info("dictionary: {} types".format(len(dictionary)))

        # tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model))
        tokenizer = spm.SentencePieceProcessor()
        tokenizer.Load(cfg.spm_model)
        return cls(cfg, dictionary, tokenizer)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        self.datasets[split] = {
            'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
            'data_dir': self.cfg.data,
            'shuffle': True if split == 'train' else False,
        }
        self.datasets[split] = Namespace(**self.datasets[split])
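
    # Illustrative only (not part of pretraining.py): load_dataset above expects
    # {cfg.data}/json/{split}.json to hold a list of shard descriptions in the format
    # documented in MLMLoader._tokenize, e.g. (paths are hypothetical)
    #   [
    #     {"source": ["shard/en/0000.txt", "shard/en/0001.txt"],
    #      "source_lang": "en", "weight": 1.0, "count": 2, "name": "en"},
    #     {"source": ["shard/de/0000.txt"],
    #      "source_lang": "de", "weight": 0.5, "count": 1, "name": "de"}
    #   ]
    # The listed files are resolved relative to cfg.data by MLMLoader._read_from_files.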

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)

        return self.datasets[split]

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        return MLMLoader(
            self.cfg,
            dataset,
            self.dictionary,
            self.tokenizer,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            max_positions=max_positions,
            ignore_invalid_inputs=ignore_invalid_inputs,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

8  examples/fairseq/train.py  Normal file
@@ -0,0 +1,8 @@
import models
import tasks

from fairseq_cli.train import cli_main


if __name__ == "__main__":
    cli_main()

0  examples/fairseq/utils/__init__.py  Normal file
75  examples/fairseq/utils/sparse_clip.py  Normal file
@@ -0,0 +1,75 @@
import torch
import warnings
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist
import math


@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None
    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)
    params = list(filter(grad_exists, params))
    grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
    # torch.distributed does not expose get_global_world_size(); use get_world_size(),
    # falling back to 1 when the default process group is not initialized.
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    denom = math.sqrt(max(world_size, moe_expert_count))
    for p in params:
        if hasattr(p, "expert"):
            expert_grads.append(p.grad.detach() / denom)
        elif hasattr(p, "base_expert"):
            base_expert_grads.append(p.grad.detach())
        elif hasattr(p, "_is_sharded"):
            sharded_grads.append(p.grad.detach())
        else:
            grads.append(p.grad.detach())
    if len(grads) == 0:
        if len(params) > 0:
            total_norm = params[0].new_tensor(0.0)
        else:
            total_norm = torch.tensor(0.0)
    elif len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    else:
        if multi_tensor_l2norm_available:
            total_norm = multi_tensor_total_norm(grads)
        else:
            if torch.cuda.is_available():
                warnings.warn(
                    "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                    "you may get better performance by installing NVIDIA's apex library"
                )
                device = torch.cuda.current_device()
            elif grads[0].device.type == "xla":
                device = grads[0].device
            else:
                device = torch.device("cpu")
            total_norm = torch.norm(
                torch.stack(
                    [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
                )
            )

    # calculate split_norm and all_reduce with other workers
    norms = [total_norm]
    for split_grads in [expert_grads, sharded_grads]:
        if len(split_grads) == 0:
            continue
        split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
        if dist.is_initialized():
            split_norm.pow_(2)
            dist.all_reduce(split_norm)
            split_norm.sqrt_()
        norms.append(split_norm)
    if len(norms) > 1:
        total_norm = torch.norm(torch.stack(norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads + sharded_grads + base_expert_grads:
            g.mul_(clip_coef)
    return total_norm
|
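For reference, the helper above generalizes stock PyTorch gradient clipping by keeping dense, MoE-expert, and sharded gradients in separate norm buckets before combining them. A minimal sketch of the plain single-bucket equivalent (this uses the standard torch.nn.utils.clip_grad_norm_, not the function from this file):

import torch

model = torch.nn.Linear(16, 4)
loss = model(torch.randn(8, 16)).sum()
loss.backward()

# Single-bucket clipping; sparse_clip.clip_grad_norm_ additionally splits MoE
# expert and sharded parameter gradients and all-reduces their norms.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(total_norm))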
25
setup.py
Normal file
@@ -0,0 +1,25 @@
from io import open
from setuptools import find_packages, setup

setup(
    name="torchscale",
    version="0.1.1",
    author="TorchScale Team",
    author_email="Shuming.Ma@microsoft.com",
    description="Transformers at any scale",
    long_description=open("README.md", "r", encoding='utf-8').read(),
    long_description_content_type="text/markdown",
    keywords="Transformers at any scale",
    license="MIT",
    url="https://github.com/msranlp/torchscale",
    packages=find_packages(exclude=["*.tests", "*.tests.*",
                                    "tests.*", "tests"]),
    install_requires=['apex',
                      'torch>=1.8',
                      'fairscale==0.4.0',
                      'timm==0.4.12'],
    python_requires='>=3.8.0',
    classifiers=[
        'Programming Language :: Python :: 3',
    ],
)
0
tests/__init__.py
Normal file
33
tests/test_decoder.py
Normal file
@@ -0,0 +1,33 @@
import pytest
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
import torch

testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"decoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "decoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"checkpoint_activations": True},
    {"fsdp": True}
]


@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
    config = DecoderConfig(**args)
    model = Decoder(config)
    prev_output_tokens = torch.ones(2, 10)
    token_embeddings = torch.rand(2, 10, config.decoder_embed_dim)
    model(
        prev_output_tokens=prev_output_tokens,
        token_embeddings=token_embeddings,
        features_only=True,
    )
28
tests/test_encoder.py
Normal file
@@ -0,0 +1,28 @@
import pytest
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
import torch

testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "encoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_encoder_input_output_embed": True},
    {"checkpoint_activations": True},
    {"fsdp": True}
]


@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
    config = EncoderConfig(**args)
    model = Encoder(config)
    token_embeddings = torch.rand(2, 10, config.encoder_embed_dim)
    model(src_tokens=None, token_embeddings=token_embeddings)
43
tests/test_encoder_decoder.py
Normal file
@@ -0,0 +1,43 @@
import pytest
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
import torch

testcases = [
    {},
    {"vocab_size": 64000},
    {"activation_fn": "relu"},
    {"drop_path_rate": 0.1},
    {"encoder_normalize_before": False, "decoder_normalize_before": False},
    {"no_scale_embedding": False},
    {"layernorm_embedding": True},
    {"rel_pos_buckets": 32, "max_rel_pos": 256},
    {"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
    {"bert_init": True},
    {"multiway": True},
    {"share_decoder_input_output_embed": True},
    {"share_all_embeddings": True},
    {"checkpoint_activations": True},
    {"fsdp": True}
]


@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
    config = EncoderDecoderConfig(**args)
    model = EncoderDecoder(
        config,
        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
        encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
        decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
    )

    src_tokens = torch.ones(2, 20).long()
    prev_output_tokens = torch.ones(2, 10).long()

    model(
        src_tokens=src_tokens,
        prev_output_tokens=prev_output_tokens,
        features_only=True,
    )
0
torchscale/__init__.py
Normal file
0
torchscale/architecture/__init__.py
Normal file
160
torchscale/architecture/config.py
Normal file
@@ -0,0 +1,160 @@
class EncoderConfig(object):
    def __init__(self, **kwargs):
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Vision
        self.img_size = kwargs.pop("img_size", 224)
        self.patch_size = kwargs.pop("patch_size", 16)
        self.in_chans = kwargs.pop("in_chans", 3)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        if self.deepnorm:
            self.encoder_normalize_before = False
        if self.subln:
            self.encoder_normalize_before = True

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)


class DecoderConfig(object):
    def __init__(self, **kwargs):
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        if self.deepnorm:
            self.decoder_normalize_before = False
        if self.subln:
            self.decoder_normalize_before = True

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)


class EncoderDecoderConfig(object):
    def __init__(self, **kwargs):
        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
        self.encoder_layers = kwargs.pop("encoder_layers", 12)
        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
        self.decoder_layers = kwargs.pop("decoder_layers", 12)
        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
        self.activation_fn = kwargs.pop("activation_fn", "gelu")
        self.dropout = kwargs.pop("dropout", 0.0)
        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
        self.moe_freq = kwargs.pop("moe_freq", 0)
        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
        self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
        self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
        self.use_xmoe = kwargs.pop("use_xmoe", True)
        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
        self.deepnorm = kwargs.pop("deepnorm", False)
        self.subln = kwargs.pop("subln", True)
        self.bert_init = kwargs.pop("bert_init", False)
        self.multiway = kwargs.pop("multiway", False)
        self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
        self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
        self.no_output_layer = kwargs.pop("no_output_layer", False)
        # Text
        self.vocab_size = kwargs.pop("vocab_size", -1)
        # Fairscale
        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
        self.fsdp = kwargs.pop("fsdp", False)
        self.ddp_rank = kwargs.pop("ddp_rank", 0)

        if self.deepnorm:
            self.encoder_normalize_before = False
            self.decoder_normalize_before = False
        if self.subln:
            self.encoder_normalize_before = True
            self.decoder_normalize_before = True

    def override(self, args):
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
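A minimal sketch (not part of the commit) of how these config objects are meant to be used: keyword arguments set fields at construction time, and override() copies any matching, non-None attribute from an argparse-style namespace.

from argparse import Namespace
from torchscale.architecture.config import DecoderConfig

config = DecoderConfig(decoder_layers=6, vocab_size=32000)

# Hypothetical CLI values; any attribute whose name matches a config field wins.
cli_args = Namespace(dropout=0.1, decoder_embed_dim=512)
config.override(cli_args)

print(config.decoder_layers, config.dropout, config.decoder_embed_dim)  # 6 0.1 512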
447
torchscale/architecture/decoder.py
Normal file
@@ -0,0 +1,447 @@
import math
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias


class DecoderLayer(nn.Module):

    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        if args.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.self_attn = self.build_self_attention(self.embed_dim, args)

        self.normalize_before = args.decoder_normalize_before

        self.self_attn_layer_norm = LayerNorm(self.embed_dim)

        if not is_encoder_decoder:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)

        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.decoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = self.build_ffn(
                self.embed_dim,
                self.args,
            )
        else:
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)

        self.final_layer_norm = LayerNorm(self.embed_dim)

        if args.deepnorm:
            if is_encoder_decoder:
                self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
            else:
                self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
        else:
            self.alpha = 1.0

        if args.subln:
            self.ffn_layernorm = LayerNorm(self.ffn_dim)
        else:
            self.ffn_layernorm = None

    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def build_encoder_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=False,
            encoder_decoder_attention=True,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        self_attn_rel_pos=None,
        cross_attn_rel_pos=None,
    ):
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)

        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            attn_mask=self_attn_mask,
            rel_pos=self_attn_rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=None,
                rel_pos=cross_attn_rel_pos,
            )
            x = self.dropout_module(x)

            if self.drop_path is not None:
                x = self.drop_path(x)

            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        return x, attn, None, l_aux


class Decoder(nn.Module):

    def __init__(
        self,
        args,
        embed_tokens=None,
        embed_positions=None,
        output_projection=None,
        is_encoder_decoder=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.args = args

        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.decoder_layers):
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_decoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder,
                )
            )

        self.num_layers = len(self.layers)

        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

        self.self_attn_relative_position = None
        self.cross_attn_relative_position = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.self_attn_relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.decoder_attention_heads,
            )
            if is_encoder_decoder:
                self.cross_attn_relative_position = RelativePositionBias(
                    num_buckets=args.rel_pos_buckets,
                    max_distance=args.max_rel_pos,
                    n_heads=args.decoder_attention_heads,
                )

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
            else:
                init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(args.decoder_layers * 3))
            else:
                init_scale = math.sqrt(math.log(args.decoder_layers * 2))
            for name, p in self.named_parameters():
                if 'encoder_attn' in name:
                    continue
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, args.vocab_size, bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
            )
        return output_projection

    def build_decoder_layer(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False
    ):
        layer = DecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder,
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,
        incremental_state=None,
    ):
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(tokens, incremental_state=incremental_state)

        if incremental_state is not None:
            tokens = tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed

    def forward(
        self,
        prev_output_tokens,
        self_attn_padding_mask=None,
        encoder_out=None,
        incremental_state=None,
        features_only=False,
        return_all_hiddens=False,
        token_embeddings=None,
        **kwargs
    ):
        # embed tokens and positions
        x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
        x = x.transpose(0, 1)

        # relative position
        self_attn_rel_pos_bias = None
        slen = prev_output_tokens.size(1)
        if self.self_attn_relative_position is not None:
            self_attn_rel_pos_bias = self.self_attn_relative_position(
                batch_size=x.size(1),
                qlen=slen,
                klen=slen
            )
            if incremental_state is not None:
                self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
        cross_attn_rel_pos_bias = None
        if self.cross_attn_relative_position is not None:
            cross_attn_rel_pos_bias = self.cross_attn_relative_position(
                batch_size=x.size(1),
                qlen=slen,
                klen=encoder_out["encoder_out"].size(0),
            )
            if incremental_state is not None:
                cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]

        # decoder layers
        inner_states = [x]

        if encoder_out is None:
            l_aux = []
        else:
            l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []

        for idx, layer in enumerate(self.layers):
            if incremental_state is None:
                self_attn_mask = torch.triu(
                    torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
                )
            else:
                self_attn_mask = None
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x, layer_attn, _, l_aux_i = layer(
                x,
                encoder_out["encoder_out"] if encoder_out is not None else None,
                encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
                incremental_state[idx] if incremental_state is not None else None,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                self_attn_rel_pos=self_attn_rel_pos_bias,
                cross_attn_rel_pos=cross_attn_rel_pos_bias,
            )
            l_aux.append(l_aux_i)
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        x = x.transpose(0, 1)

        if not features_only:
            x = self.output_layer(x)

        return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}

    def output_layer(self, features):
        return self.output_projection(features)
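A minimal sketch (not part of the commit, and assuming the apex and fairscale dependencies listed in setup.py are installed) of a decoder-only language model that owns its token embedding and returns vocabulary logits:

import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
from torchscale.component.embedding import TextEmbedding

config = DecoderConfig(vocab_size=32000)
decoder = Decoder(config, embed_tokens=TextEmbedding(32000, config.decoder_embed_dim))

tokens = torch.randint(0, 32000, (2, 16))
logits, extra = decoder(prev_output_tokens=tokens)
print(logits.shape)  # (2, 16, 32000); extra["inner_states"] holds per-layer features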
367
torchscale/architecture/encoder.py
Normal file
@@ -0,0 +1,367 @@
import math
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias


class EncoderLayer(nn.Module):

    def __init__(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False
    ):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        if args.drop_path_rate > 0:
            drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
            self.drop_path = DropPath(drop_path_prob)
        else:
            self.drop_path = None

        self.normalize_before = args.encoder_normalize_before
        self.is_moe_layer = is_moe_layer
        self.ffn_dim = args.encoder_ffn_embed_dim

        if not self.is_moe_layer:
            self.ffn = MultiwayWrapper(
                args,
                self.build_ffn(
                    self.embed_dim,
                    self.args,
                )
            )
        else:
            assert not self.args.multiway
            if args.moe_top1_expert:
                gate = Top1Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    use_fp32=args.moe_gating_use_fp32,
                    moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            else:
                gate = Top2Gate(
                    self.embed_dim,
                    args.moe_expert_count,
                    args.moe_gating_use_fp32,
                    args.moe_second_expert_policy,
                    args.moe_normalize_gate_prob_before_dropping,
                    args.moe_eval_capacity_token_fraction,
                    use_xmoe=args.use_xmoe,
                )
            experts = make_experts(args, self.embed_dim, self.ffn_dim)
            self.moe_layer = MOELayer(gate, experts, args)
        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))

        if args.deepnorm:
            if is_encoder_decoder:
                self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
            else:
                self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
        else:
            self.alpha = 1.0

    def build_ffn(self, embed_dim, args):
        return FeedForwardNetwork(
            embed_dim,
            self.ffn_dim,
            args.activation_fn,
            args.dropout,
            args.activation_dropout,
            args.subln,
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            args,
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            encoder_decoder_attention=False,
            subln=args.subln,
        )

    def residual_connection(self, x, residual):
        return residual * self.alpha + x

    def forward(
        self,
        x,
        encoder_padding_mask,
        attn_mask=None,
        rel_pos=None
    ):
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
            rel_pos=rel_pos,
        )
        x = self.dropout_module(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        if not self.is_moe_layer:
            x = self.ffn(x)
            l_aux = None
        else:
            x = x.transpose(0, 1)
            x, l_aux = self.moe_layer(x)
            x = x.transpose(0, 1)

        if self.drop_path is not None:
            x = self.drop_path(x)

        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, l_aux


class Encoder(nn.Module):

    def __init__(
        self,
        args,
        embed_tokens=None,
        embed_positions=None,
        output_projection=None,
        is_encoder_decoder=False,
        **kwargs
    ):
        self.args = args
        super().__init__(**kwargs)

        self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

        embed_dim = args.encoder_embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens
        self.embed_positions = embed_positions

        if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.encoder_layers):
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_encoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                    is_encoder_decoder=is_encoder_decoder
                )
            )
        self.num_layers = len(self.layers)

        if args.encoder_normalize_before:
            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
        else:
            self.layer_norm = None

        if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
            self.relative_position = RelativePositionBias(
                num_buckets=args.rel_pos_buckets,
                max_distance=args.max_rel_pos,
                n_heads=args.encoder_attention_heads,
            )
        else:
            self.relative_position = None

        if args.bert_init:
            self.apply(init_bert_params)

        if args.deepnorm:
            if is_encoder_decoder:
                init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
            else:
                init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
            for name, p in self.named_parameters():
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.div_(init_scale)

        if args.subln:
            if is_encoder_decoder:
                init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
            else:
                init_scale = math.sqrt(math.log(args.encoder_layers * 2))
            for name, p in self.named_parameters():
                if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
                    p.data.mul_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        if args.share_encoder_input_output_embed:
            assert args.encoder_embedding_type == 'language'
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.encoder_embed_dim, args.vocab_size, bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
            )
        return output_projection

    def build_encoder_layer(
        self,
        args,
        depth,
        is_moe_layer=False,
        is_encoder_decoder=False
    ):
        layer = EncoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
            is_encoder_decoder=is_encoder_decoder
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        src_tokens,
        token_embedding=None,
    ):
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            if src_tokens is not None:
                x = embed + self.embed_positions(src_tokens)
            else:
                x = embed + self.embed_positions(x)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        encoder_padding_mask=None,
        return_all_hiddens=False,
        token_embeddings=None,
        multiway_split_position=None,
        features_only=False,
        **kwargs
    ):
        assert src_tokens is not None or token_embeddings is not None

        if encoder_padding_mask is None:
            if src_tokens is not None:
                encoder_padding_mask = torch.zeros_like(
                    src_tokens,
                    device=src_tokens.device
                ).bool()
            else:
                encoder_padding_mask = torch.zeros(
                    [token_embeddings.size(0), token_embeddings.size(1)],
                    device=token_embeddings.device
                ).bool()

        if multiway_split_position is not None:
            assert self.args.multiway
            self.apply(set_split_position(multiway_split_position))

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
        x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        x = x.transpose(0, 1)

        encoder_states = []

        if return_all_hiddens:
            encoder_states.append(x)

        rel_pos_bias = None
        if self.relative_position is not None:
            rel_pos_bias = self.relative_position(
                batch_size=x.size(1),
                qlen=x.size(0),
                klen=x.size(0)
            )

        l_aux = []
        for layer in self.layers:
            x, l_aux_i = layer(
                x, encoder_padding_mask=encoder_padding_mask,
                rel_pos=rel_pos_bias
            )
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)
            l_aux.append(l_aux_i)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only and self.output_projection is not None:
            x = self.output_projection(x)

        return {
            "encoder_out": x,
            "encoder_embedding": encoder_embedding,
            "encoder_padding_mask": encoder_padding_mask,
            "encoder_states": encoder_states,
            "l_aux": l_aux,
        }
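A minimal sketch (not part of the commit, same dependency assumptions as above) of running the encoder on token ids rather than the pre-computed embeddings used in the tests:

import torch
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import TextEmbedding

config = EncoderConfig(vocab_size=32000)
encoder = Encoder(config, embed_tokens=TextEmbedding(32000, config.encoder_embed_dim))

src_tokens = torch.randint(0, 32000, (2, 16))
out = encoder(src_tokens, features_only=True)
print(out["encoder_out"].shape)  # (16, 2, 768): (seq_len, batch, embed_dim)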
61
torchscale/architecture/encoder_decoder.py
Normal file
@@ -0,0 +1,61 @@
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.decoder import Decoder


class EncoderDecoder(nn.Module):

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            is_encoder_decoder=True,
            **kwargs
        )

        if args.share_all_embeddings and decoder_embed_tokens is None:
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            is_encoder_decoder=True,
            **kwargs
        )

    def forward(
        self,
        src_tokens,
        prev_output_tokens,
        return_all_hiddens=False,
        features_only=False,
        **kwargs
    ):
        encoder_out = self.encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out
30
torchscale/architecture/utils.py
Normal file
@@ -0,0 +1,30 @@
import torch.nn as nn
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.multiway_network import MultiwayNetwork


def init_bert_params(module):

    def normal_(data):
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        if isinstance(module.q_proj, MultiwayNetwork):
            normal_(module.q_proj.A.weight.data)
            normal_(module.q_proj.B.weight.data)
            normal_(module.k_proj.A.weight.data)
            normal_(module.k_proj.B.weight.data)
            normal_(module.v_proj.A.weight.data)
            normal_(module.v_proj.B.weight.data)
        else:
            normal_(module.q_proj.weight.data)
            normal_(module.k_proj.weight.data)
            normal_(module.v_proj.weight.data)
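A minimal sketch (not part of the commit, assuming the same optional dependencies): init_bert_params is intended to be passed to Module.apply, which visits every submodule, as the Encoder and Decoder do when bert_init is set.

import torch.nn as nn
from torchscale.architecture.utils import init_bert_params

toy = nn.Sequential(nn.Linear(8, 8), nn.Embedding(100, 8, padding_idx=0))
toy.apply(init_bert_params)  # weights -> N(0, 0.02); biases and the padding row zeroed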
0
torchscale/component/__init__.py
Normal file
16
torchscale/component/droppath.py
Normal file
@@ -0,0 +1,16 @@
from timm.models.layers import drop_path
import torch.nn as nn


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)
120
torchscale/component/embedding.py
Normal file
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):

    def __init__(
        self,
        text_embed,
        vision_embed
    ):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(
        self,
        textual_tokens,
        visual_tokens,
        **kwargs
    ):
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        return torch.cat([x1, x2], dim=1)


class VisionEmbedding(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def forward(
        self,
        x,
        masked_position=None,
        **kwargs
    ):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        if masked_position is not None:
            assert self.mask_token is not None
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class TextEmbedding(nn.Embedding):

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
        self._fill_padding_idx_with_zero()


class PositionalEmbedding(nn.Embedding):

    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # being consistent with Fairseq, which starts from 2.
            positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
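A minimal sketch (not part of the commit) of the patch-embedding path: VisionEmbedding turns an image batch into a token sequence, optionally prepending a CLS token.

import torch
from torchscale.component.embedding import VisionEmbedding

patch_embed = VisionEmbedding(img_size=224, patch_size=16, in_chans=3,
                              embed_dim=768, prepend_cls_token=True)
images = torch.randn(2, 3, 224, 224)
tokens = patch_embed(images)
print(tokens.shape)  # (2, 197, 768): 14 * 14 patches plus one prepended CLS token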
119
torchscale/component/feedforward_network.py
Normal file
@@ -0,0 +1,119 @@
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
|
||||||
|
|
||||||
|
class set_torch_seed(object):
|
||||||
|
def __init__(self, seed):
|
||||||
|
assert isinstance(seed, int)
|
||||||
|
self.rng_state = self.get_rng_state()
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
def get_rng_state(self):
|
||||||
|
state = {"torch_rng_state": torch.get_rng_state()}
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
state["cuda_rng_state"] = torch.cuda.get_rng_state()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def set_rng_state(self, state):
|
||||||
|
torch.set_rng_state(state["torch_rng_state"])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.set_rng_state(state["cuda_rng_state"])
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self.set_rng_state(self.rng_state)
|
||||||
|
|
||||||
|
|
||||||
|
def make_experts(args, embed_dim, expert_ffn_dim):
|
||||||
|
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
|
||||||
|
expert_list = []
|
||||||
|
ddp_rank = args.ddp_rank
|
||||||
|
start_seed = torch.randint(1000000, (1,)).item()
|
||||||
|
# at least as many experts than gpus
|
||||||
|
if args.moe_expert_count >= world_size:
|
||||||
|
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
|
||||||
|
local_moe_expert_count = args.moe_expert_count // world_size
|
||||||
|
for i in range(local_moe_expert_count):
|
||||||
|
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
|
||||||
|
expert_list.append(
|
||||||
|
FeedForwardNetwork(
|
||||||
|
embed_dim,
|
||||||
|
expert_ffn_dim,
|
||||||
|
args.activation_fn,
|
||||||
|
args.dropout,
|
||||||
|
args.activation_dropout,
|
||||||
|
args.subln
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
|
||||||
|
|
||||||
|
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
|
||||||
|
expert_list.append(
|
||||||
|
FeedForwardNetwork(
|
||||||
|
embed_dim,
|
||||||
|
expert_ffn_dim,
|
||||||
|
args.activation_fn,
|
||||||
|
args.dropout,
|
||||||
|
args.activation_dropout,
|
||||||
|
args.subln
|
||||||
|
)
|
||||||
|
)
|
||||||
|
experts = nn.ModuleList(expert_list)
|
||||||
|
return experts
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation_fn(activation):
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
return lambda x: F.gelu(x.float()).type_as(x)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForwardNetwork(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
ffn_dim,
|
||||||
|
activation_fn,
|
||||||
|
dropout,
|
||||||
|
activation_dropout,
|
||||||
|
subln=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.activation_fn = get_activation_fn(activation=str(activation_fn))
|
||||||
|
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
|
||||||
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
|
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
|
||||||
|
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
|
||||||
|
self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
self.fc1.reset_parameters()
|
||||||
|
self.fc2.reset_parameters()
|
||||||
|
if self.ffn_layernorm is not None:
|
||||||
|
self.ffn_layernorm.reset_parameters()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_shape = x.shape
|
||||||
|
x = x.reshape(-1, x.size(-1))
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.activation_fn(x)
|
||||||
|
x = self.activation_dropout_module(x)
|
||||||
|
if self.ffn_layernorm is not None:
|
||||||
|
x = self.ffn_layernorm(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = x.view(x_shape)
|
||||||
|
x = self.dropout_module(x)
|
||||||
|
return x
|
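# Usage sketch (illustrative, not part of the original diff): the hyper-parameters
# below are arbitrary, and subln is left False so the fused LayerNorm dependency
# is not needed for this example.
import torch

_ffn = FeedForwardNetwork(
    embed_dim=512,
    ffn_dim=2048,
    activation_fn="gelu",
    dropout=0.1,
    activation_dropout=0.0,
    subln=False,
)
_y = _ffn(torch.randn(16, 8, 512))  # output keeps the input shape: (16, 8, 512)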
117
torchscale/component/multihead_attention.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
from .multiway_network import MultiwayWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.0,
|
||||||
|
self_attention=False,
|
||||||
|
encoder_decoder_attention=False,
|
||||||
|
subln=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.self_attention = self_attention
|
||||||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
assert self.self_attention ^ self.encoder_decoder_attention
|
||||||
|
|
||||||
|
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
|
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
|
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
|
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
|
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
|
||||||
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||||
|
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||||
|
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
incremental_state=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
attn_mask=None,
|
||||||
|
rel_pos=None,
|
||||||
|
):
|
||||||
|
tgt_len, bsz, embed_dim = query.size()
|
||||||
|
src_len = tgt_len
|
||||||
|
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
||||||
|
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||||
|
|
||||||
|
src_len, key_bsz, _ = key.size()
|
||||||
|
assert key_bsz == bsz, f"{query.size(), key.size()}"
|
||||||
|
assert value is not None
|
||||||
|
assert (src_len, bsz) == value.shape[:2]
|
||||||
|
|
||||||
|
q = self.q_proj(query)
|
||||||
|
k = self.k_proj(key)
|
||||||
|
v = self.v_proj(value)
|
||||||
|
q *= self.scaling
|
||||||
|
|
||||||
|
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||||||
|
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||||||
|
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
if "prev_key" in incremental_state:
|
||||||
|
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
k = torch.cat([prev_key, k], dim=1)
|
||||||
|
v = torch.cat([prev_value, v], dim=1)
|
||||||
|
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
||||||
|
src_len = k.size(1)
|
||||||
|
|
||||||
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_weights = torch.nan_to_num(attn_weights)
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_weights += attn_mask
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.masked_fill(
|
||||||
|
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
||||||
|
float("-inf"),
|
||||||
|
)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if rel_pos is not None:
|
||||||
|
rel_pos = rel_pos.view(attn_weights.size())
|
||||||
|
attn_weights = attn_weights + rel_pos
|
||||||
|
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
|
||||||
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
|
attn = torch.bmm(attn_probs, v)
|
||||||
|
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||||||
|
|
||||||
|
if self.inner_attn_ln is not None:
|
||||||
|
attn = self.inner_attn_ln(attn)
|
||||||
|
|
||||||
|
attn = self.out_proj(attn)
|
||||||
|
attn_weights = attn_weights.view(
|
||||||
|
bsz, self.num_heads, tgt_len, src_len
|
||||||
|
).transpose(1, 0)
|
||||||
|
|
||||||
|
return attn, attn_weights
|
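# Shape sketch (illustrative, not part of the original diff): forward() expects
# (length, batch, embed_dim) inputs and folds the head dimension into the batch
# dimension before the two bmm calls. The same core math, restated standalone
# with toy sizes:
import torch
import torch.nn.functional as F

_tgt, _bsz, _heads, _hdim = 5, 2, 4, 8
_q = torch.randn(_tgt, _bsz, _heads * _hdim) * _hdim ** -0.5
_k = torch.randn(_tgt, _bsz, _heads * _hdim)
_v = torch.randn(_tgt, _bsz, _heads * _hdim)

_q = _q.view(_tgt, _bsz * _heads, _hdim).transpose(0, 1)
_k = _k.view(-1, _bsz * _heads, _hdim).transpose(0, 1)
_v = _v.view(-1, _bsz * _heads, _hdim).transpose(0, 1)

_weights = F.softmax(torch.bmm(_q, _k.transpose(1, 2)), dim=-1)  # (bsz*heads, tgt, src)
_attn = torch.bmm(_weights, _v).transpose(0, 1).reshape(_tgt, _bsz, _heads * _hdim)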
39
torchscale/component/multiway_network.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def MultiwayWrapper(args, module, dim=0):
|
||||||
|
if args.multiway:
|
||||||
|
return MultiwayNetwork(module, dim=dim)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def set_split_position(position):
|
||||||
|
|
||||||
|
def apply_fn(module):
|
||||||
|
if hasattr(module, 'split_position'):
|
||||||
|
module.split_position = position
|
||||||
|
|
||||||
|
return apply_fn
|
||||||
|
|
||||||
|
|
||||||
|
class MultiwayNetwork(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, module, dim=0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.A = module
|
||||||
|
self.B = copy.deepcopy(module)
|
||||||
|
self.B.reset_parameters()
|
||||||
|
self.split_position = -1
|
||||||
|
|
||||||
|
def forward(self, x, **kwargs):
|
||||||
|
if self.split_position == -1:
|
||||||
|
return self.A(x, **kwargs)
|
||||||
|
if self.split_position == 0:
|
||||||
|
return self.B(x, **kwargs)
|
||||||
|
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim)
|
||||||
|
# x1, x2 = x[:self.split_position], x[self.split_position:]
|
||||||
|
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
|
||||||
|
return torch.cat([y1, y2], dim=self.dim)
|
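# Usage sketch (illustrative, not part of the original diff): once a split
# position is set, the leading slice of the chosen dim goes through branch A
# (the wrapped module) and the rest through branch B (its re-initialized copy).
# Sizes below are arbitrary.
import torch
import torch.nn as nn

_net = MultiwayNetwork(nn.Linear(8, 8), dim=1)
_net.apply(set_split_position(4))
_y = _net(torch.randn(2, 10, 8))  # positions 0-3 -> A, positions 4-9 -> B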
79
torchscale/component/relative_position_bias.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionBias(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
bidirectional=True,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128,
|
||||||
|
n_heads=12
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.bidirectional = bidirectional
|
||||||
|
self.num_buckets = num_buckets
|
||||||
|
self.max_distance = max_distance
|
||||||
|
self.n_heads = n_heads
|
||||||
|
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _relative_position_bucket(
|
||||||
|
relative_position,
|
||||||
|
bidirectional=True,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128
|
||||||
|
):
|
||||||
|
ret = 0
|
||||||
|
n = -relative_position
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
ret += (n < 0).to(torch.long) * num_buckets
|
||||||
|
n = torch.abs(n)
|
||||||
|
else:
|
||||||
|
n = torch.max(n, torch.zeros_like(n))
|
||||||
|
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = n < max_exact
|
||||||
|
|
||||||
|
val_if_large = max_exact + (
|
||||||
|
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||||
|
).to(torch.long)
|
||||||
|
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
||||||
|
|
||||||
|
ret += torch.where(is_small, n, val_if_large)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def compute_bias(
|
||||||
|
self,
|
||||||
|
qlen,
|
||||||
|
klen,
|
||||||
|
step=None
|
||||||
|
):
|
||||||
|
step = 0 if step is None else step
|
||||||
|
context_position = torch.arange(step, step + qlen, dtype=torch.long,
|
||||||
|
device=self.relative_attention_bias.weight.device)[:, None]
|
||||||
|
memory_position = torch.arange(klen, dtype=torch.long,
|
||||||
|
device=self.relative_attention_bias.weight.device)[None, :]
|
||||||
|
relative_position = memory_position - context_position # shape (qlen, klen)
|
||||||
|
|
||||||
|
rp_bucket = self._relative_position_bucket(
|
||||||
|
relative_position, # shape (qlen, klen)
|
||||||
|
bidirectional=self.bidirectional,
|
||||||
|
num_buckets=self.num_buckets,
|
||||||
|
)
|
||||||
|
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
|
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
batch_size,
|
||||||
|
qlen,
|
||||||
|
klen,
|
||||||
|
step=None
|
||||||
|
):
|
||||||
|
# shape (batch * num_heads, qlen, klen)
|
||||||
|
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
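# Usage sketch (illustrative, not part of the original diff): the module returns
# a (batch * n_heads, qlen, klen) tensor that can be handed to the attention
# layer as rel_pos and added onto the attention logits.
_bias = RelativePositionBias(bidirectional=True, num_buckets=32, max_distance=128, n_heads=12)
_rel_pos = _bias(batch_size=2, qlen=5, klen=5)  # shape: (2 * 12, 5, 5)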
0
torchscale/component/xmoe/__init__.py
Normal file
310
torchscale/component/xmoe/moe_layer.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# NOTE: This is a mirror of the code in
|
||||||
|
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Tuple, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import Module, ModuleList
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fairseq.modules.moe import MOELayer
|
||||||
|
has_fairseq = True
|
||||||
|
Base = MOELayer
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
Base = Module
|
||||||
|
has_fairseq = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# To enable Tutel MoE optimizations:
|
||||||
|
# python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
|
||||||
|
from tutel import moe as tutel_moe
|
||||||
|
|
||||||
|
has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
|
||||||
|
# See https://arxiv.org/pdf/2006.16668.pdf for details.
|
||||||
|
|
||||||
|
# Based on https://github.com/pytorch/pytorch/pull/40762
|
||||||
|
class _AllToAll(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
|
||||||
|
ctx.group = group
|
||||||
|
input = input.contiguous()
|
||||||
|
output = torch.empty_like(input)
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
dist.all_to_all_single(output, input, group=group)
|
||||||
|
else:
|
||||||
|
assert group is None
|
||||||
|
output = input
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
|
||||||
|
return (None, _AllToAll.apply(ctx.group, *grad_output))
|
||||||
|
|
||||||
|
|
||||||
|
def _find_my_group_index(grouped_ranks):
|
||||||
|
my_rank = dist.get_rank()
|
||||||
|
for i, group in enumerate(grouped_ranks):
|
||||||
|
if my_rank in group:
|
||||||
|
return i
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_group(moe_expert_count):
|
||||||
|
if dist.is_initialized():
|
||||||
|
if not hasattr(get_moe_group, "_moe_groups"):
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
if world_size <= moe_expert_count:
|
||||||
|
assert moe_expert_count % world_size == 0
|
||||||
|
moe_groups = [[i] for i in range(world_size)]
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert world_size % moe_expert_count == 0
|
||||||
|
ranks_per_group = world_size // moe_expert_count
|
||||||
|
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
|
||||||
|
for i in range(moe_expert_count)]
|
||||||
|
|
||||||
|
get_moe_group._moe_group_idx = moe_groups
|
||||||
|
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
||||||
|
|
||||||
|
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
|
||||||
|
return get_moe_group._moe_groups[my_group_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def get_all2all_group(moe_expert_count):
|
||||||
|
if dist.is_initialized():
|
||||||
|
if not hasattr(get_all2all_group, "_all2all_groups"):
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
# more experts than world size
|
||||||
|
if world_size <= moe_expert_count:
|
||||||
|
assert moe_expert_count % world_size == 0
|
||||||
|
all2all_groups = [[i for i in range(world_size)]]
|
||||||
|
|
||||||
|
# larger world than num experts
|
||||||
|
else:
|
||||||
|
assert world_size % moe_expert_count == 0
|
||||||
|
ranks_per_group = world_size // moe_expert_count
|
||||||
|
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
|
||||||
|
for i in range(ranks_per_group)]
|
||||||
|
|
||||||
|
get_all2all_group._all2all_group_idx = all2all_groups
|
||||||
|
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]
|
||||||
|
|
||||||
|
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
|
||||||
|
return get_all2all_group._all2all_groups[my_group_idx]
|
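# Grouping sketch (illustrative, not part of the original diff): without touching
# torch.distributed, the rank lists built above look like this. With more experts
# than ranks (world_size=4, moe_expert_count=8), every rank is its own expert
# group and a single all-to-all group spans all ranks:
_moe_groups = [[i] for i in range(4)]                 # [[0], [1], [2], [3]]
_all2all_groups = [[i for i in range(4)]]             # [[0, 1, 2, 3]]
# In the opposite regime (world_size=8, moe_expert_count=4), each expert is
# sharded over world_size // moe_expert_count = 2 ranks:
_moe_groups_sharded = [[i + j * 4 for j in range(2)] for i in range(4)]      # [[0, 4], [1, 5], [2, 6], [3, 7]]
_all2all_groups_sharded = [[i * 4 + j for j in range(4)] for i in range(2)]  # [[0, 1, 2, 3], [4, 5, 6, 7]]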
||||||
|
|
||||||
|
|
||||||
|
class MOELayer(Base):
|
||||||
|
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
|
||||||
|
::
|
||||||
|
|
||||||
|
gate = Top2Gate(model_dim, num_experts)
|
||||||
|
moe = MOELayer(gate, expert)
|
||||||
|
output = moe(input)
|
||||||
|
l_aux = moe.l_aux
|
||||||
|
|
||||||
|
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gate (torch.nn.Module):
|
||||||
|
gate network
|
||||||
|
expert (torch.nn.Module):
|
||||||
|
expert network
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
gate,
|
||||||
|
experts,
|
||||||
|
args
|
||||||
|
):
|
||||||
|
if has_fairseq:
|
||||||
|
super(Base, self).__init__()
|
||||||
|
else:
|
||||||
|
super().__init__()
|
||||||
|
self.gate = gate
|
||||||
|
if type(experts) == ModuleList:
|
||||||
|
self.experts = cast(ModuleList, experts)
|
||||||
|
else:
|
||||||
|
self.experts = ModuleList([experts])
|
||||||
|
self.expert_group = get_moe_group(args.moe_expert_count)
|
||||||
|
self.all2all_group = get_all2all_group(args.moe_expert_count)
|
||||||
|
self.world_size = dist.get_world_size(group=self.expert_group)
|
||||||
|
self.all2all_size = dist.get_world_size(group=self.all2all_group)
|
||||||
|
for p in experts.parameters():
|
||||||
|
p.expert = True # type: ignore
|
||||||
|
self.num_local_experts = len(self.experts)
|
||||||
|
self.args = args
|
||||||
|
self.in_generation = False
|
||||||
|
self.a2a_cuda_event_intervals = []
|
||||||
|
self.a2a_cpu_time_ms = 0.0
|
||||||
|
|
||||||
|
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
|
||||||
|
assert len(input) == 1, "only single input Tensor supported"
|
||||||
|
input = input[0]
|
||||||
|
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
|
||||||
|
if input_padding_mask is not None:
|
||||||
|
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
|
||||||
|
assert input_padding_mask.shape[0] == input.shape[0]
|
||||||
|
assert input_padding_mask.shape[1] == input.shape[1]
|
||||||
|
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
|
||||||
|
|
||||||
|
# Implement Algorithm 2 from GShard paper.
|
||||||
|
d_model = input.shape[2]
|
||||||
|
# Pad to expected batch size
|
||||||
|
input_shape = list(input.shape)
|
||||||
|
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
|
||||||
|
# This indicates that --batch-size or --max-sentences is not specified
|
||||||
|
if expected_bsz is None:
|
||||||
|
expected_bsz = 0
|
||||||
|
# Note: Padding is not necessary at generation time at present
|
||||||
|
# because all DDP workers process the same batch. Also, batch size at generation time
|
||||||
|
# can be different from that present in the checkpoint state
|
||||||
|
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
|
||||||
|
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
|
||||||
|
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
|
||||||
|
padded_input = torch.zeros(
|
||||||
|
(expected_bsz, input_shape[1], input_shape[2]),
|
||||||
|
dtype=input.dtype, layout=input.layout, device=input.device)
|
||||||
|
padded_input[:input_shape[0], :, :] = input
|
||||||
|
input = padded_input
|
||||||
|
|
||||||
|
padded_input_padding_mask = torch.ones(
|
||||||
|
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
|
||||||
|
)
|
||||||
|
if input_padding_mask is not None:
|
||||||
|
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
|
||||||
|
else:
|
||||||
|
padded_input_padding_mask[:input_shape[0], :] = False
|
||||||
|
input_padding_mask = padded_input_padding_mask
|
||||||
|
|
||||||
|
# Reshape into S tokens by dropping sequence dimension.
|
||||||
|
reshaped_input = input.reshape(-1, d_model)
|
||||||
|
reshaped_input_shape = reshaped_input.shape
|
||||||
|
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None
|
||||||
|
|
||||||
|
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||||
|
# Pro of --max-tokens: more flexible for MT variable sequence lengths
|
||||||
|
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
|
||||||
|
if expected_bsz == 0:
|
||||||
|
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
|
||||||
|
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
|
||||||
|
expected_dim = int(expected_dim.item())
|
||||||
|
padded_input = torch.zeros(
|
||||||
|
(expected_dim, reshaped_input_shape[1]),
|
||||||
|
dtype=input.dtype, layout=input.layout, device=input.device)
|
||||||
|
padded_input[:reshaped_input_shape[0], :] = reshaped_input
|
||||||
|
reshaped_input = padded_input
|
||||||
|
|
||||||
|
padded_input_padding_mask = torch.ones(
|
||||||
|
(expected_dim,), dtype=torch.bool, device=padded_input.device
|
||||||
|
)
|
||||||
|
if reshaped_input_padding_mask is not None:
|
||||||
|
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
|
||||||
|
else:
|
||||||
|
padded_input_padding_mask[:reshaped_input_shape[0]] = False
|
||||||
|
reshaped_input_padding_mask = padded_input_padding_mask
|
||||||
|
|
||||||
|
if has_tutel:
|
||||||
|
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
|
||||||
|
S, M = reshaped_input.size(0), reshaped_input.size(1)
|
||||||
|
|
||||||
|
if not hasattr(self, '_tutel_dispatcher'):
|
||||||
|
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
|
||||||
|
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
|
||||||
|
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
|
||||||
|
else:
|
||||||
|
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)
|
||||||
|
|
||||||
|
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S
|
||||||
|
E, C, S = dispatch_mask.size()
|
||||||
|
M = reshaped_input.size(1)
|
||||||
|
assert reshaped_input.size() == (S, M)
|
||||||
|
# einsum("sec,sm->ecm")
|
||||||
|
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M
|
||||||
|
|
||||||
|
if self.all2all_size > 1:
|
||||||
|
dispatched_input = self.all_to_all_wrapper(dispatched_input)
|
||||||
|
|
||||||
|
# Re-shape after all-to-all: ecm -> gecm
|
||||||
|
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
|
||||||
|
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
|
||||||
|
expert_outputs = []
|
||||||
|
for chunk, expert in zip(chunks, self.experts):
|
||||||
|
expert_outputs += [expert(chunk)]
|
||||||
|
expert_output = torch.cat(expert_outputs, dim=1)
|
||||||
|
|
||||||
|
if self.all2all_size > 1:
|
||||||
|
expert_output = self.all_to_all_wrapper(expert_output)
|
||||||
|
|
||||||
|
# Re-shape back: gecm -> ecm
|
||||||
|
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)
|
||||||
|
|
||||||
|
if has_tutel:
|
||||||
|
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
|
||||||
|
else:
|
||||||
|
# einsum("sec,ecm->sm")
|
||||||
|
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))
|
||||||
|
|
||||||
|
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||||
|
combined_output = combined_output[:reshaped_input_shape[0], :]
|
||||||
|
combined_output = combined_output.reshape(input.shape)
|
||||||
|
combined_output = combined_output[:input_shape[0], :, :]
|
||||||
|
|
||||||
|
self.record_all_to_all_stats()
|
||||||
|
|
||||||
|
return combined_output, l_aux
|
||||||
|
|
||||||
|
def prepare_for_inference_(self):
|
||||||
|
self.in_generation = True
|
||||||
|
|
||||||
|
def all_to_all_wrapper(self, input: Tensor):
|
||||||
|
dummy_a2a = getattr(self.args, 'dummy_a2a', False)
|
||||||
|
if dummy_a2a:
|
||||||
|
input = input.contiguous()
|
||||||
|
output = input.detach().clone()
|
||||||
|
return input
|
||||||
|
# always record times, since it is not a lot of overhead
|
||||||
|
# if we do not log it we simply clear it off in record_all_to_all_stats
|
||||||
|
cuda_start = torch.cuda.Event(enable_timing=True)
|
||||||
|
cuda_end = torch.cuda.Event(enable_timing=True)
|
||||||
|
cpu_start = time.time() * 1000
|
||||||
|
cuda_start.record()
|
||||||
|
output = _AllToAll.apply(self.all2all_group, input)
|
||||||
|
cuda_end.record()
|
||||||
|
cpu_end = time.time() * 1000
|
||||||
|
self.a2a_cpu_time_ms += (cpu_end - cpu_start)
|
||||||
|
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
|
||||||
|
return output
|
||||||
|
|
||||||
|
def record_all_to_all_stats(self):
|
||||||
|
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
|
||||||
|
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
|
||||||
|
if record_a2a_perf_stats:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
|
||||||
|
a2a_cuda_time_ms = 0.0
|
||||||
|
for ev_start, ev_end in self.a2a_cuda_event_intervals:
|
||||||
|
a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
|
||||||
|
self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
|
||||||
|
# reset stats
|
||||||
|
self.a2a_cpu_time_ms = 0.0
|
||||||
|
self.a2a_cuda_event_intervals = []
|
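# Dispatch/combine sketch (illustrative, not part of the original diff): the two
# matrix products in forward() are flattened einsums. With toy sizes:
import torch

_S, _E, _C, _M = 6, 2, 3, 4
_reshaped_input = torch.randn(_S, _M)
_dispatch_mask = torch.rand(_E, _C, _S)        # already permuted to E,C,S as in forward()
_combine_weights = torch.rand(_S, _E, _C)
_expert_output = torch.randn(_E, _C, _M)

# einsum("sec,sm->ecm"): scatter tokens into (expert, capacity) slots
_dispatched = torch.mm(_dispatch_mask.view(_E * _C, _S), _reshaped_input).view(_E, _C, _M)
# einsum("sec,ecm->sm"): gather expert outputs back per token
_combined = _combine_weights.view(_S, _E * _C).mm(_expert_output.view(_E * _C, _M))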
459
torchscale/component/xmoe/routing.py
Normal file
|
@ -0,0 +1,459 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
|
||||||
|
# Code is inspired by Top2GatingOnLogits from lingvo:
|
||||||
|
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
|
||||||
|
|
||||||
|
# NOTE: This is a mirror of the code in
|
||||||
|
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||||
|
|
||||||
|
from typing import Callable, Dict, Tuple, Optional
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .moe_layer import has_tutel, fused_cumsum_sub_one
|
||||||
|
|
||||||
|
# use a fixed temperature to compute balance loss
|
||||||
|
TEMPERATURE_FOR_L_UAX = 0.07
|
||||||
|
|
||||||
|
# maximum capacity of 1 expert as a fraction of number of tokens in the batch
|
||||||
|
# Note: setting this to 1.0 causes inference to significantly slow down
|
||||||
|
EVAL_CAPACITY_TOKEN_FRACTION = 0.25
|
||||||
|
|
||||||
|
# logging
|
||||||
|
SAMPLE_FRACTION = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
def top1gating(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
input_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_fp32=False,
|
||||||
|
capacity_factor=1.0,
|
||||||
|
eval_mode=False,
|
||||||
|
moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
|
||||||
|
use_xmoe=False,
|
||||||
|
gate_obj=None,
|
||||||
|
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
|
||||||
|
"""Implements Top2Gating on logits."""
|
||||||
|
metadata = {}
|
||||||
|
if use_fp32:
|
||||||
|
orig_dtype = logits.dtype
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
gates = F.softmax(logits, dim=1)
|
||||||
|
metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
|
||||||
|
|
||||||
|
# gates has shape of SE
|
||||||
|
num_tokens = gates.shape[0]
|
||||||
|
num_experts = gates.shape[1]
|
||||||
|
if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
|
||||||
|
capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
|
||||||
|
else:
|
||||||
|
# capacity = capacity_factor * S/E
|
||||||
|
capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))
|
||||||
|
|
||||||
|
# Create a mask for 1st's expert per token
|
||||||
|
indices1_s = torch.argmax(gates, dim=1)
|
||||||
|
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
|
||||||
|
if input_mask is not None and input_mask.any():
|
||||||
|
nonpadding = ~ input_mask
|
||||||
|
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
|
|
||||||
|
# for logging (percent of tokens routed to each expert)
|
||||||
|
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||||
|
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||||
|
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||||
|
|
||||||
|
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||||
|
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||||
|
metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
|
||||||
|
|
||||||
|
gates1_s = (gates * mask1).sum(dim=1)
|
||||||
|
|
||||||
|
# Compute locations in capacity buffer
|
||||||
|
locations1 = fused_cumsum_sub_one(mask1)
|
||||||
|
|
||||||
|
# Compute l_aux
|
||||||
|
me = torch.mean(gates, dim=0)
|
||||||
|
ce = torch.mean(mask1.to(gates.dtype), dim=0)
|
||||||
|
|
||||||
|
l_aux = torch.mean(me * ce)
|
||||||
|
l_aux = l_aux * num_experts * num_experts
|
||||||
|
|
||||||
|
if has_tutel:
|
||||||
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
|
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]
|
||||||
|
|
||||||
|
# Remove locations outside capacity from mask
|
||||||
|
mask1 = mask1 * torch.lt(locations1, capacity)
|
||||||
|
# Store the capacity location for each token
|
||||||
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
|
|
||||||
|
# Calculate combine_weights and dispatch_mask
|
||||||
|
gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
|
||||||
|
# locations1_sc = num_tokens * capacity
|
||||||
|
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
|
||||||
|
combine1_sec = torch.bmm(
|
||||||
|
# einsum("se,sc->sec")
|
||||||
|
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
||||||
|
)
|
||||||
|
dispatch_mask = combine1_sec.bool()
|
||||||
|
if use_fp32:
|
||||||
|
return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
|
||||||
|
else:
|
||||||
|
return l_aux, combine1_sec, dispatch_mask, metadata
|
||||||
|
|
||||||
|
|
||||||
|
class Top1Gate(torch.nn.Module):
|
||||||
|
"""Gate module which implements Top2Gating as described in Gshard_.
|
||||||
|
::
|
||||||
|
|
||||||
|
gate = Top1Gate(model_dim, num_experts)
|
||||||
|
l_aux, combine_weights, dispatch_mask = gate(input)
|
||||||
|
|
||||||
|
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dim (int):
|
||||||
|
size of model embedding dimension
|
||||||
|
num_experts (ints):
|
||||||
|
number of experts in model
|
||||||
|
"""
|
||||||
|
|
||||||
|
wg: torch.nn.Linear
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dim: int,
|
||||||
|
num_experts: int,
|
||||||
|
use_fp32=False,
|
||||||
|
input_noise_type=None,
|
||||||
|
capacity_factor=1.0,
|
||||||
|
moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
|
||||||
|
use_xmoe=False,
|
||||||
|
) -> None:
|
||||||
|
# TODO: merge this to top2gate.py
|
||||||
|
#
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not use_xmoe:
|
||||||
|
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
|
||||||
|
else:
|
||||||
|
self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
|
||||||
|
wg = torch.empty(num_experts, 16)
|
||||||
|
torch.nn.init.orthogonal_(wg, gain=0.32)
|
||||||
|
self.register_parameter("wg", torch.nn.Parameter(wg))
|
||||||
|
|
||||||
|
self.use_xmoe = use_xmoe
|
||||||
|
self.use_fp32 = use_fp32
|
||||||
|
self.input_noise_type = input_noise_type
|
||||||
|
self.capacity_factor = capacity_factor
|
||||||
|
self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
|
||||||
|
|
||||||
|
def forward(self, input, mask=None): # type: ignore
|
||||||
|
if self.use_xmoe:
|
||||||
|
input = self.wg_reduction(input)
|
||||||
|
with torch.no_grad():
|
||||||
|
wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
|
||||||
|
self.wg.mul_(1.5 / wg_norm)
|
||||||
|
logits = self._cosine(input, self.wg)
|
||||||
|
logits = self._make_finite(logits)
|
||||||
|
else:
|
||||||
|
logits = self.wg(input)
|
||||||
|
|
||||||
|
return top1gating(
|
||||||
|
logits,
|
||||||
|
mask,
|
||||||
|
use_fp32=self.use_fp32,
|
||||||
|
capacity_factor=self.capacity_factor,
|
||||||
|
eval_mode=not self.training,
|
||||||
|
moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
|
||||||
|
use_xmoe=self.use_xmoe,
|
||||||
|
gate_obj=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_finite(self, scores):
|
||||||
|
ok = scores.isfinite()
|
||||||
|
if not ok.all():
|
||||||
|
# NaNs here can break the assignment algorithm
|
||||||
|
scores[~ok] = scores[ok].min()
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def _get_gating_temperature(self, eps=1e-4):
|
||||||
|
if self.gating_t.data.item() < eps:
|
||||||
|
return eps
|
||||||
|
return self.gating_t
|
||||||
|
|
||||||
|
def _cosine(self, mat1, mat2, eps=1e-4):
|
||||||
|
assert mat1.dim() == 2
|
||||||
|
assert mat2.dim() == 2
|
||||||
|
# mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
|
||||||
|
mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
|
||||||
|
return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
|
||||||
|
|
||||||
|
|
||||||
|
gumbel_map: Dict[torch.device, Callable] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
|
||||||
|
gumbel = gumbel_map.get(device)
|
||||||
|
if gumbel is None:
|
||||||
|
one = torch.tensor(1.0, device=device)
|
||||||
|
zero = torch.tensor(0.0, device=device)
|
||||||
|
gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
|
||||||
|
gumbel_map[device] = gumbel
|
||||||
|
return gumbel(shape)
|
||||||
|
|
||||||
|
|
||||||
|
def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
|
||||||
|
if unsqueeze_indices:
|
||||||
|
indices = indices.unsqueeze(-1)
|
||||||
|
assert indices.shape[-1] == 1, "last dimension of indices must have size 1"
|
||||||
|
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype)
|
||||||
|
output.scatter_(
|
||||||
|
len(output.shape) - 1, indices, 1
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def entropy(probs):
|
||||||
|
logits = torch.distributions.utils.probs_to_logits(probs)
|
||||||
|
p_log_p = probs * logits
|
||||||
|
return -p_log_p.sum(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def top2gating(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
input_mask: Optional[torch.Tensor] = None,
|
||||||
|
use_fp32=False,
|
||||||
|
second_expert_policy='sampling',
|
||||||
|
normalize_gate_prob_before_dropping=False,
|
||||||
|
eval_mode=False,
|
||||||
|
moe_eval_capacity_token_fraction=0.25,
|
||||||
|
batch_prioritized_routing=False,
|
||||||
|
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
|
||||||
|
"""Implements Top2Gating on logits."""
|
||||||
|
metadata = {}
|
||||||
|
if use_fp32:
|
||||||
|
orig_dtype = logits.dtype
|
||||||
|
logits = logits.float()
|
||||||
|
gates = F.softmax(logits, dim=1)
|
||||||
|
metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
|
||||||
|
# gates has shape of SE
|
||||||
|
num_tokens = gates.shape[0]
|
||||||
|
num_experts = gates.shape[1]
|
||||||
|
if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
|
||||||
|
capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
|
||||||
|
else:
|
||||||
|
# capacity = 2S/E
|
||||||
|
capacity = 2 * math.ceil(num_tokens / num_experts)
|
||||||
|
|
||||||
|
# Create a mask for 1st's expert per token
|
||||||
|
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
|
||||||
|
mask1 = one_hot(indices1_s, num_experts)
|
||||||
|
if second_expert_policy == 'sampling':
|
||||||
|
# Create a mask for 2nd's expert per token using Gumbel-max trick
|
||||||
|
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
|
||||||
|
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
|
||||||
|
else:
|
||||||
|
logits_w_noise = logits
|
||||||
|
# Replace top-expert with min value
|
||||||
|
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
|
||||||
|
indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
|
||||||
|
mask2 = one_hot(indices2_s, num_experts)
|
||||||
|
gates1_s = (gates * mask1).sum(dim=1)
|
||||||
|
gates2_s = (gates * mask2).sum(dim=1)
|
||||||
|
|
||||||
|
if normalize_gate_prob_before_dropping:
|
||||||
|
# Normalize gate probabilities
|
||||||
|
denom_s = gates1_s + gates2_s
|
||||||
|
# Avoid divide-by-zero
|
||||||
|
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
|
||||||
|
gates1_s = gates1_s / denom_s
|
||||||
|
gates2_s = gates2_s / denom_s
|
||||||
|
|
||||||
|
if second_expert_policy == 'random':
|
||||||
|
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
|
||||||
|
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
|
||||||
|
|
||||||
|
# Compute locations in capacity buffer
|
||||||
|
if input_mask is not None and input_mask.any():
|
||||||
|
nonpadding = ~ input_mask
|
||||||
|
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
|
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
|
|
||||||
|
if batch_prioritized_routing:
|
||||||
|
# if batch_prioritized_routing:
|
||||||
|
importance_scores = -1 * gates.max(dim=1)[0]
|
||||||
|
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
|
||||||
|
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
|
||||||
|
importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
|
||||||
|
|
||||||
|
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
|
||||||
|
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
|
||||||
|
importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
|
||||||
|
|
||||||
|
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
|
||||||
|
else:
|
||||||
|
locations1 = fused_cumsum_sub_one(mask1)
|
||||||
|
locations2 = fused_cumsum_sub_one(mask2)
|
||||||
|
# Update 2nd's location by accounting for locations of 1st
|
||||||
|
locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# Compute l_aux
|
||||||
|
me = torch.mean(gates, dim=0)
|
||||||
|
ce = torch.mean(mask1.to(gates.dtype), dim=0)
|
||||||
|
l_aux = torch.mean(me * ce)
|
||||||
|
l_aux = l_aux * num_experts * num_experts
|
||||||
|
|
||||||
|
# for logging purposes
|
||||||
|
metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
|
||||||
|
metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
|
||||||
|
|
||||||
|
# Remove locations outside capacity from mask
|
||||||
|
mask1_, mask2_ = mask1, mask2
|
||||||
|
mask1 = mask1 * torch.lt(locations1, capacity)
|
||||||
|
mask2 = mask2 * torch.lt(locations2, capacity)
|
||||||
|
|
||||||
|
# for logging (percent of tokens routed to each expert)
|
||||||
|
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||||
|
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||||
|
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||||
|
|
||||||
|
expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||||
|
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
|
||||||
|
expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||||
|
|
||||||
|
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||||
|
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||||
|
metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
|
||||||
|
|
||||||
|
metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
|
||||||
|
metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()
|
||||||
|
|
||||||
|
if not normalize_gate_prob_before_dropping:
|
||||||
|
# Normalize gate probabilities
|
||||||
|
gates1_s = (gates * mask1).sum(dim=1)
|
||||||
|
gates2_s = (gates * mask2).sum(dim=1)
|
||||||
|
denom_s = gates1_s + gates2_s
|
||||||
|
# Avoid divide-by-zero
|
||||||
|
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
|
||||||
|
gates1_s /= denom_s
|
||||||
|
gates2_s /= denom_s
|
||||||
|
|
||||||
|
if has_tutel:
|
||||||
|
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
||||||
|
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
||||||
|
return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
|
||||||
|
|
||||||
|
# Store the capacity location for each token
|
||||||
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
|
locations2_s = torch.sum(locations2 * mask2, dim=1)
|
||||||
|
|
||||||
|
# Calculate combine_weights and dispatch_mask
|
||||||
|
gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
|
||||||
|
gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se")
|
||||||
|
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
|
||||||
|
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
|
||||||
|
combine1_sec = torch.bmm(
|
||||||
|
# einsum("se,sc->sec")
|
||||||
|
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
||||||
|
)
|
||||||
|
combine2_sec = torch.bmm(
|
||||||
|
# einsum("se,sc->sec")
|
||||||
|
gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
|
||||||
|
)
|
||||||
|
combine_weights = combine1_sec + combine2_sec
|
||||||
|
dispatch_mask = combine_weights.bool()
|
||||||
|
if use_fp32:
|
||||||
|
return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
|
||||||
|
else:
|
||||||
|
return l_aux, combine_weights, dispatch_mask, metadata
|
||||||
|
|
||||||
|
|
||||||
|
class Top2Gate(torch.nn.Module):
|
||||||
|
"""Gate module which implements Top2Gating as described in Gshard_.
|
||||||
|
::
|
||||||
|
|
||||||
|
gate = Top2Gate(model_dim, num_experts)
|
||||||
|
l_aux, combine_weights, dispatch_mask = gate(input)
|
||||||
|
|
||||||
|
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dim (int):
|
||||||
|
size of model embedding dimension
|
||||||
|
num_experts (ints):
|
||||||
|
number of experts in model
|
||||||
|
"""
|
||||||
|
|
||||||
|
wg: torch.nn.Linear
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_dim: int,
|
||||||
|
num_experts: int,
|
||||||
|
use_fp32=False,
|
||||||
|
second_expert_policy='sampling',
|
||||||
|
normalize_gate_prob_before_dropping=False,
|
||||||
|
moe_eval_capacity_token_fraction=0.25,
|
||||||
|
batch_prioritized_routing=False,
|
||||||
|
use_xmoe=False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
if not use_xmoe:
|
||||||
|
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
|
||||||
|
else:
|
||||||
|
self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
|
||||||
|
wg = torch.empty(num_experts, 16)
|
||||||
|
torch.nn.init.orthogonal_(wg, gain=0.32)
|
||||||
|
self.register_parameter("wg", torch.nn.Parameter(wg))
|
||||||
|
self.use_fp32 = use_fp32
|
||||||
|
self.second_expert_policy = second_expert_policy
|
||||||
|
self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
|
||||||
|
self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
|
||||||
|
self.batch_prioritized_routing = batch_prioritized_routing
|
||||||
|
self.use_xmoe = use_xmoe
|
||||||
|
|
||||||
|
def forward(self, input, mask=None): # type: ignore
|
||||||
|
if self.use_xmoe:
|
||||||
|
input = self.wg_reduction(input)
|
||||||
|
with torch.no_grad():
|
||||||
|
wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
|
||||||
|
self.wg.mul_(1.5 / wg_norm)
|
||||||
|
logits = self._cosine(input, self.wg)
|
||||||
|
logits = self._make_finite(logits)
|
||||||
|
else:
|
||||||
|
logits = self.wg(input)
|
||||||
|
return top2gating(
|
||||||
|
logits,
|
||||||
|
mask,
|
||||||
|
use_fp32=self.use_fp32,
|
||||||
|
second_expert_policy=self.second_expert_policy,
|
||||||
|
normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
|
||||||
|
eval_mode=not self.training,
|
||||||
|
moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
|
||||||
|
batch_prioritized_routing=self.batch_prioritized_routing,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _cosine(self, mat1, mat2, eps=1e-4):
|
||||||
|
assert mat1.dim() == 2
|
||||||
|
assert mat2.dim() == 2
|
||||||
|
# mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
|
||||||
|
mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
|
||||||
|
return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
|
||||||
|
|
||||||
|
def _make_finite(self, scores):
|
||||||
|
ok = scores.isfinite()
|
||||||
|
if not ok.all():
|
||||||
|
# NaNs here can break the assignment algorithm
|
||||||
|
scores[~ok] = scores[ok].min()
|
||||||
|
return scores
|
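# Routing sketch (illustrative, not part of the original diff): a self-contained
# restatement of the top-1 path above with plain torch ops -- pick each token's
# best expert, then drop assignments beyond the per-expert capacity
# int(capacity_factor * ceil(S / E)).
import math
import torch
import torch.nn.functional as F

_S, _E, _capacity_factor = 8, 2, 1.0
_gates = torch.softmax(torch.randn(_S, _E), dim=1)
_capacity = int(_capacity_factor * math.ceil(_S / _E))

_indices1_s = torch.argmax(_gates, dim=1)              # best expert per token
_mask1 = F.one_hot(_indices1_s, _E)                    # (S, E)
_locations1 = torch.cumsum(_mask1, dim=0) - 1          # slot index within each expert
_mask1 = _mask1 * (_locations1 < _capacity)            # drop tokens past capacity
_gates1_s = (_gates * _mask1).sum(dim=1)               # kept gate value per token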
85
torchscale/model/BEiT3.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchscale.architecture.encoder import Encoder
|
||||||
|
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
|
||||||
|
from torchscale.component.multiway_network import MultiwayWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class BEiT3(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.args = args
|
||||||
|
assert args.multiway
|
||||||
|
assert args.vocab_size > 0
|
||||||
|
assert not args.share_encoder_input_output_embed
|
||||||
|
self.text_embed = TextEmbedding(
|
||||||
|
args.vocab_size,
|
||||||
|
args.encoder_embed_dim
|
||||||
|
)
|
||||||
|
self.vision_embed = VisionEmbedding(
|
||||||
|
args.img_size,
|
||||||
|
args.patch_size,
|
||||||
|
args.in_chans,
|
||||||
|
args.encoder_embed_dim,
|
||||||
|
contain_mask_token=True,
|
||||||
|
prepend_cls_token=True
|
||||||
|
)
|
||||||
|
embed_positions = MultiwayWrapper(
|
||||||
|
args,
|
||||||
|
PositionalEmbedding(
|
||||||
|
args.max_source_positions,
|
||||||
|
args.encoder_embed_dim
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
self.encoder = Encoder(
|
||||||
|
args,
|
||||||
|
embed_tokens=None,
|
||||||
|
embed_positions=embed_positions,
|
||||||
|
output_projection=None,
|
||||||
|
is_encoder_decoder=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
textual_tokens=None,
|
||||||
|
visual_tokens=None,
|
||||||
|
text_padding_position=None,
|
||||||
|
vision_masked_position=None,
|
||||||
|
):
|
||||||
|
assert textual_tokens is not None or visual_tokens is not None
|
||||||
|
|
||||||
|
if textual_tokens is None:
|
||||||
|
x = self.vision_embed(visual_tokens, vision_masked_position)
|
||||||
|
encoder_padding_mask = None
|
||||||
|
multiway_split_position = -1
|
||||||
|
elif visual_tokens is None:
|
||||||
|
x = self.text_embed(textual_tokens)
|
||||||
|
encoder_padding_mask = text_padding_position
|
||||||
|
multiway_split_position = 0
|
||||||
|
else:
|
||||||
|
x1 = self.vision_embed(visual_tokens, vision_masked_position)
|
||||||
|
multiway_split_position = x1.size(1)
|
||||||
|
x2 = self.text_embed(textual_tokens)
|
||||||
|
x = torch.cat([x1, x2], dim=1)
|
||||||
|
|
||||||
|
if text_padding_position is not None:
|
||||||
|
encoder_padding_mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
|
||||||
|
text_padding_position
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_padding_mask = None
|
||||||
|
|
||||||
|
encoder_out = self.encoder(
|
||||||
|
src_tokens=None,
|
||||||
|
encoder_padding_mask=encoder_padding_mask,
|
||||||
|
token_embeddings=x,
|
||||||
|
multiway_split_position=multiway_split_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
return encoder_out
|
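# Mask sketch (illustrative, not part of the original diff): when both modalities
# are present, vision tokens are never padded, so the encoder padding mask is
# zeros over the image positions concatenated with the text padding positions,
# and the multiway split position is the number of image tokens. Toy sizes only.
import torch

_bsz, _num_img_tokens, _num_text_tokens = 2, 5, 7
_text_padding_position = torch.zeros(_bsz, _num_text_tokens, dtype=torch.bool)
_text_padding_position[:, -2:] = True  # pretend the last two text tokens are padding

_encoder_padding_mask = torch.cat(
    [torch.zeros(_bsz, _num_img_tokens, dtype=torch.bool), _text_padding_position],
    dim=1,
)  # (bsz, num_img_tokens + num_text_tokens)
_multiway_split_position = _num_img_tokens  # image positions -> branch A, text -> branch B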
0
torchscale/model/__init__.py
Normal file