torchscale released

pull/1/head
shumingma 2022-11-23 08:21:58 +07:00
parent 41f6ee5687
commit ede048831f
47 changed files with 5230 additions and 29 deletions

162
.gitignore vendored

@ -348,3 +348,165 @@ MigrationBackup/
# Ionide (cross platform F# VS Code tools) working folder
.ionide/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

@ -1,11 +1,176 @@
# OpenScale - Transformers at (any) Scale
# TorchScale - A Library for Transformers at (Any) Scale
Fundamental research to improve modeling generality and capability, as well as training stability and efficiency of scaling Transformers.
<p>
<a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
<a href="https://pypi.org/project/torchscale"><img alt="MIT License" src="https://badge.fury.io/py/torchscale.svg" /></a>
</p>
TorchScale is a PyTorch library that allows researchers and developeres to scale up Transformers efficiently and effectively.
It has the implemetention of fundamental research to improve modeling generality and capability, as well as training stability and efficiency of scaling Transformers.
- Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
- Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423)
- Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
## News
- November, 2022: TorchScale 0.1.1 released
## Installation
To install:
```
pip install torchscale
```
Alternatively, you can develop it locally:
```
git clone https://github.com/microsoft/torchscale.git
cd torchscale
pip install -e .
```
## Getting Started
It takes only several lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
```python
>>> from torchscale.architecture.config import EncoderConfig
>>> from torchscale.architecture.encoder import Encoder
>>> config = EncoderConfig(vocab_size=64000)
>>> model = Encoder(config)
>>> print(model)
```
We also support the `Decoder` architecture and the `EncoderDecoder` architecture:
```python
# Creating a decoder model
>>> from torchscale.architecture.config import DecoderConfig
>>> from torchscale.architecture.decoder import Decoder
>>> config = DecoderConfig(vocab_size=64000)
>>> decoder = Decoder(config)
>>> print(decoder)
# Creating a encoder-decoder model
>>> from torchscale.architecture.config import EncoderDecoderConfig
>>> from torchscale.architecture.encoder_decoder import EncoderDecoder
>>> config = EncoderDecoderConfig(vocab_size=64000)
>>> encdec = EncoderDecoder(config)
>>> print(encdec)
```
## Examples
We have the examples of how to use TorchScale in the following scenarios/tasks:
- Language
* [Decoder/GPT](examples/fairseq/README.md#example-gpt-pretraining)
* [Encoder-Decoder/Neural Machine Translation](examples/fairseq/README.md#example-machine-translation)
* [Encoder/BERT](examples/fairseq/README.md#example-bert-pretraining)
- Vision
* ViT/BEiT [In progress]
- Speech
- Multimodal
* [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
## Results
### Stability Evaluation
<p align="center">
<img src="./assets/convergence.png" width="800"/>
</p>
The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
### Scaling-up Experiments
<p align="center">
<img src="./assets/scaling_curve.png" width="800"/>
</p>
TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain.
## Acknowledgments
Some implementations in TorchScale are either adapted from or inspired by the [FairSeq](https://github.com/facebookresearch/fairseq) repository and the [UniLM](https://github.com/microsoft/unilm) repository.
## Citations
If you find this repository useful, please consider citing our work:
```
@article{deepnet,
author = {Hongyu Wang and
Shuming Ma and
Li Dong and
Shaohan Huang and
Dongdong Zhang and
Furu Wei},
title = {{DeepNet}: Scaling Transformers to 1,000 Layers},
journal = {CoRR},
volume = {abs/2203.00555},
year = {2022},
}
```
```
@article{magneto,
author = {Hongyu Wang and
Shuming Ma and
Shaohan Huang and
Li Dong and
Wenhui Wang and
Zhiliang Peng and
Yu Wu and
Payal Bajaj and
Saksham Singhal and
Alon Benhaim and
Barun Patra and
Zhun Liu and
Vishrav Chaudhary and
Xia Song and
Furu Wei},
title = {Foundation Transformers},
journal = {CoRR},
volume = {abs/2210.06423},
year = {2022}
}
```
```
@article{xmoe,
author = {Zewen Chi and
Li Dong and
Shaohan Huang and
Damai Dai and
Shuming Ma and
Barun Patra and
Saksham Singhal and
Payal Bajaj and
Xia Song and
Furu Wei},
title = {On the Representation Collapse of Sparse Mixture of Experts},
journal = {CoRR},
volume = {abs/2204.09179},
year = {2022}
}
```
## Contributing
@ -19,7 +184,7 @@ provided by the bot. You will only need to do this once across all repos using o
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
## Trademarks
@ -27,4 +192,4 @@ This project may contain trademarks or logos for projects, products, or services
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.
Any use of third-party trademarks or logos are subject to those third-party's policies.

@ -1,25 +1,25 @@
# TODO: The maintainer of this repo has not yet edited this file
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
# Support
## How to file issues and get help
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
## Microsoft Support Policy
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
# TODO: The maintainer of this repo has not yet edited this file
**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
# Support
## How to file issues and get help
This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.
For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
## Microsoft Support Policy
Support for this **PROJECT or PRODUCT** is limited to the resources listed above.

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

@ -0,0 +1,233 @@
# Example: Integration with FairSeq
## Setup
```bash
# Install the repo as a package:
git clone https://github.com/msranlp/torchscale.git
cd torchscale
pip install -e .
pip install git+https://github.com/shumingma/fairseq.git@moe
pip install git+https://github.com/shumingma/infinibatch.git
pip install iopath
pip install --upgrade numpy
```
## Example: BERT Pretraining
### Data Format
We use a [streaming dataloader](https://github.com/microsoft/infinibatch) to read the data on-the-fly from the disk. It requires the data sharded into multiple small files (e.g. 10K lines per file), as well as a JSON file to contain some meta data and the paths to these files.
The overall data directory should be organized as follows:
```
Data/
├── json/
│ ├── train.json
│ └── valid.json
├── shard/
│ ├── train/
│ │ ├── 00000.txt
│ │ ├── 00001.txt
│ │ └── ...
│ └── valid/
│ ├── 00000.txt
│ ├── 00001.txt
│ └── ...
├── dict.txt
└── sentencepiece.bpe.model
```
We recommend that each sharded data files contains no more than 10K lines with one sentence per line, and two documents should be separated with an empty line.
```
Document 1 Line 1
Document 1 Line 2
Document 1 Line 3
Document 2 Line 1
Document 2 Line 2
...
```
Also, the JSON file should be in the format like this:
```
[
{
"source": [
"shard/train/00000.txt",
"shard/train/00001.txt",
...
],
"source_lang": "en",
"weight": 1.0
}
]
```
### Training Command
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 train.py ${PATH_TO_DATA} \
--task pretraining \
--tokens-per-sample 512 \
--mask-prob 0.15 \
--span-length 3.0 \
--leave-unmasked-prob 0.0 \
--random-token-prob 0.0 \
--criterion masked_lm \
--arch mlm_base \
--share-encoder-input-output-embed \
--required-batch-size-multiple 8 \
--spm-model ${PATH_TO_DATA}/sentencepiece.bpe.model \
--dict-file ${PATH_TO_DATA}/dict.txt \
--optimizer adam \
--adam-betas '(0.9,0.98)' \
--adam-eps 1e-6 \
--clip-norm 2.0 \
--lr-scheduler polynomial_decay \
--lr 0.0005 \
--warmup-updates 10000 \
--total-num-update 125000 \
--max-update 125000 \
--max-sentences 32 \
--update-freq 1 \
--log-format simple \
--log-interval 100 \
--disable-validation \
--save-interval-updates 5000 \
--no-epoch-checkpoints \
--fp16 \
--fp16-init-scale 4 \
--fp16-scale-window 256 \
--min-loss-scale 0.0001 \
--seed 1 \
--save-dir ${PATH_TO_CKPT} \
--ddp-backend=no_c10d \
--distributed-no-spawn \
--reset-dataloader \
--batch-read-ahead 10000 \
--rel-pos-buckets 32 \
--max-rel-pos 128 \
--deepnorm
```
## Example: GPT Pretraining
### Data Format
We use the format as in the FairSeq's [language modeling example](https://github.com/facebookresearch/fairseq/tree/main/examples/language_model#1-preprocess-the-data).
### Dense Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--num-workers 2 \
--activation-fn gelu \
--share-decoder-input-output-embed \
--validate-interval-updates 1000 \
--save-interval-updates 1000 \
--no-epoch-checkpoints \
--memory-efficient-fp16 \
--fp16-init-scale 4 \
--arch lm_base \
--task language_modeling \
--sample-break-mode none \
--tokens-per-sample 128 \
--optimizer adam --adam-betas "(0.9, 0.98)" \
--adam-eps 1e-08 \
--clip-norm 0.0 \
--lr 5e-4 \
--lr-scheduler polynomial_decay \
--warmup-updates 750 \
--dropout 0.1 \
--attention-dropout 0.1 \
--weight-decay 0.01 \
--batch-size 4 \
--update-freq 1 \
--required-batch-size-multiple 1 \
--total-num-update 50000 \
--max-update 50000 \
--seed 1 \
--ddp-backend=c10d
```
### Sparse (MoE) Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--num-workers 2 \
--activation-fn gelu \
--share-decoder-input-output-embed \
--validate-interval-updates 1000 \
--save-interval-updates 1000 \
--no-epoch-checkpoints \
--memory-efficient-fp16 \
--fp16-init-scale 4 \
--arch lm_base \
--task language_modeling \
--sample-break-mode none \
--tokens-per-sample 128 \
--optimizer adam --adam-betas "(0.9, 0.98)" \
--adam-eps 1e-08 \
--clip-norm 0.0 \
--lr 5e-4 \
--lr-scheduler polynomial_decay \
--warmup-updates 750 \
--dropout 0.1 \
--attention-dropout 0.1 \
--weight-decay 0.01 \
--batch-size 4 \
--update-freq 1 \
--required-batch-size-multiple 1 \
--total-num-update 50000 \
--max-update 50000 \
--seed 1 \
--ddp-backend=no_c10d \
--moe-expert-count 2 --moe-freq 2 \
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
--moe-eval-capacity-token-fraction -1.0 \
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
--use-xmoe
```
## Example: Machine Translation
### Data Format
We follow the FairSeq's [neural machine translation example](https://github.com/facebookresearch/fairseq/tree/main/examples/translation#training-a-new-model) to preprocess the data.
### Dense Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--arch mt_base --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--dropout 0.3 --weight-decay 0.0001 \
--max-tokens 4096 --fp16
```
### Sparse (MoE) Model
```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--arch mt_base --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
--dropout 0.3 --weight-decay 0.0001 \
--moe-expert-count 2 --moe-freq 2 \
--moe-gating-use-fp32 --moe-second-expert-policy random --moe-normalize-gate-prob-before-dropping \
--moe-eval-capacity-token-fraction -1.0 \
--criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
--use-xmoe \
--max-tokens 4096 --fp16
```

@ -0,0 +1,7 @@
import models
import tasks
from fairseq_cli.generate import cli_main
if __name__ == "__main__":
cli_main()

@ -0,0 +1,7 @@
import models
import tasks
from fairseq_cli.interactive import cli_main
if __name__ == "__main__":
cli_main()

@ -0,0 +1,33 @@
import argparse
import importlib
import os
MODEL_REGISTRY = {}
MODEL_DATACLASS_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_NAME_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
path = os.path.join(models_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
model_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("models." + model_name)
# extra `model_parser` for sphinx
if model_name in MODEL_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_archs = parser.add_argument_group("Named architectures")
group_archs.add_argument(
"--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]
)
group_args = parser.add_argument_group("Additional command-line arguments")
MODEL_REGISTRY[model_name].add_args(group_args)
globals()[model_name + "_parser"] = parser

@ -0,0 +1,459 @@
import math
import logging
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
)
from fairseq.modules import PositionalEmbedding
from fairseq.models.squad import SQuADHead
from torch import Tensor
from omegaconf import II
from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig
from apex.normalization import FusedLayerNorm as LayerNorm
DEFAULT_MAX_SOURCE_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class BertConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
attention_dropout: float = field(
default=0.0, metadata={"help": "dropout probability for attention weights"}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
encoder_embed_dim: int = field(
default=512, metadata={"help": "encoder embedding dimension"}
)
encoder_output_dim: int = field(
default=512, metadata={"help": "encoder output dimension"}
)
encoder_input_dim: int = field(
default=512, metadata={"help": "encoder input dimension"}
)
encoder_ffn_embed_dim: int = field(
default=2048, metadata={"help": "encoder embedding dimension for FFN"}
)
encoder_layers: int = field(default=6, metadata={"help": "num encoder layers"})
encoder_attention_heads: int = field(
default=8, metadata={"help": "num encoder attention heads"}
)
encoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each encoder block"}
)
no_encoder_final_norm: bool = field(
default=False,
metadata={"help": "don't add an extra layernorm after the last encoder block"},
)
no_token_positional_embeddings: bool = field(
default=False,
metadata={
"help": "if set, disables positional embeddings (outside self attention)"
},
)
share_encoder_input_output_embed: bool = field(
default=False, metadata={"help": "share encoder input and output embeddings"}
)
encoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the encoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add layernorm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
encoder_layerdrop: float = field(
default=0.0, metadata={"help": "LayerDrop probability for encoder"}
)
encoder_layers_to_keep: Optional[str] = field(
default=None,
metadata={
"help": "which layers to *keep* when pruning as a comma-separated list"
},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
)
max_source_positions: int = field(
default=1024, metadata={"help": "max source positions"}
)
pooler_activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use for pooler layer"}
)
pooler_dropout: float = field(
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
)
# options from other parts of the config
# add_bos_token: bool = II("task.add_bos_token")
# tokens_per_sample: int = II("task.tokens_per_sample")
tpu: bool = II("common.tpu")
rel_pos_buckets: int = field(
default=0, metadata={"help": ""}
)
max_rel_pos: int = field(
default=0, metadata={"help": ""}
)
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
}
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
)
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
def __init__(self, args, encoder):
super().__init__()
self.args = args
self.encoder = encoder
self.padding_idx = self.encoder.embed_tokens.padding_idx
self.classification_heads = nn.ModuleDict()
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
args.max_source_positions = getattr(
args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.dictionary, args.encoder_embed_dim
)
embed_positions = (
PositionalEmbedding(
args.max_source_positions,
args.encoder_embed_dim,
task.dictionary.pad(),
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
lm_head = cls.build_lm_head(
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
)
config = EncoderConfig()
config.override(args)
encoder = Encoder(
config,
embed_tokens=embed_tokens,
embed_positions=embed_positions,
output_projection=lm_head,
is_encoder_decoder=False,
dictionary=task.dictionary,
)
return cls(args, encoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
return embed_tokens
@classmethod
def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
return LMHead(embed_dim, output_dim, activation_fn, weight)
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
'and inner_dim {} (prev: {})'.format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = ClassificationHead(
self.args.encoder_embed_dim,
inner_dim or self.args.encoder_embed_dim,
num_classes,
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
def register_question_answering_head(self, name, num_classes=None):
self.classification_heads[name] = SQuADHead(
self.args.encoder_embed_dim,
)
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)
# Handle new classification heads present in the state dict.
current_head_names = (
[] if not hasattr(self, 'classification_heads')
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
continue
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
if getattr(self.args, 'load_checkpoint_heads', False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
'deleting classification head ({}) from checkpoint '
'not present in current model: {}'.format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes != self.classification_heads[head_name].out_proj.out_features
or inner_dim != self.classification_heads[head_name].dense.out_features
):
logger.warning(
'deleting classification head ({}) from checkpoint '
'with different dimensions than current model: {}'.format(head_name, k)
)
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + 'classification_heads.' + k not in state_dict:
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
def forward(
self,
src_tokens=None,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
masked_tokens=None,
**kwargs):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
elif not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class LMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the masked tokens while training,
# saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x
@register_model_architecture("mlm", "mlm_base")
def base_unilm_architecture(args):
if hasattr(args, "encoder_final_norm"):
args.no_encoder_final_norm = not args.encoder_final_norm
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
# args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_encoder_input_output_embed = getattr(
args, "share_encoder_input_output_embed", True
)
args.encoder_output_dim = getattr(
args, "encoder_output_dim", args.encoder_embed_dim
)
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
# Model training is not stable without this
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True

@ -0,0 +1,357 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import options, utils
from fairseq import distributed_utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
)
from fairseq.modules import PositionalEmbedding
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.config import DecoderConfig
from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024
import logging
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
attention_dropout: float = field(
default=0.0, metadata={"help": "dropout probability for attention weights"}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
relu_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
decoder_input_dim: int = field(
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_attention_heads: int = field(
default=8, metadata={"help": "num decoder attention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each decoder block"}
)
no_token_positional_embeddings: bool = field(
default=False,
metadata={
"help": "if set, disables positional embeddings (outside self attention)"
},
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add layernorm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
)
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
}
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
)
use_xmoe: Optional[bool] = field(
default=False,
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("common.tpu")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
fp16: bool = II("common.fp16")
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
ddp_backend: str = II("distributed_training.ddp_backend")
world_size: int = II("distributed_training.distributed_world_size")
distributed_rank: int = II("distributed_training.distributed_rank")
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
rel_pos_buckets: Optional[int] = field(
default=0,
)
max_rel_pos: Optional[int] = field(
default=0,
)
@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
embed_positions = (
PositionalEmbedding(
args.max_target_positions,
args.decoder_embed_dim,
task.dictionary.pad(),
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
)
if (
getattr(args, 'moe_freq', 0) > 0
and (
getattr(args, 'fp16', False)
and not getattr(args, 'memory_efficient_fp16', False)
and getattr(args, 'ddp_backend', None) != "fully_sharded"
)
):
assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = DecoderConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=False,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return Embedding(len(dictionary), embed_dim, dictionary.pad())
class LMDecoder(Decoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
def max_positions(self):
return self.embed_positions.max_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.base_layers = getattr(args, "base_layers", 0)
args.base_sublayers = getattr(args, "base_sublayers", 1)
args.base_shuffle = getattr(args, "base_shuffle", False)
args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True

@ -0,0 +1,450 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap
from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
from torch import Tensor
import logging
logger = logging.getLogger(__name__)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use')
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
help='dropout probability after activation in FFN.')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--decoder-output-dim', type=int, metavar='N',
help='decoder output dimension (extra linear layer '
'if different from decoder embed dim')
parser.add_argument('--share-decoder-input-output-embed', action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
parser.add_argument('--layernorm-embedding', action='store_true',
help='add layernorm to embedding')
parser.add_argument('--no-scale-embedding', action='store_true',
help='if True, dont scale embeddings')
parser.add_argument('--checkpoint-activations', action='store_true',
help='checkpoint activations at each layer, which saves GPU '
'memory usage at the cost of some additional compute')
parser.add_argument('--offload-activations', action='store_true',
help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
help='iterative PQ quantization noise at training time')
parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
help='block size of quantization noise at training time')
parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
help='scalar quantization noise and scalar quantization at training time')
# args for Fully Sharded Data Parallel (FSDP) training
parser.add_argument(
'--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
help=(
'minimum number of params for a layer to be wrapped with FSDP() when '
'training with --ddp-backend=fully_sharded. Smaller values will '
'improve memory efficiency, but may make torch.distributed '
'communication less efficient due to smaller input sizes. This option '
'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
'--offload-activations are passed.'
)
)
# args for mixture-of-expert layers
parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
help='Frequency at which we insert MoE Transformer layers')
parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
help='Frequency at which we insert MoE Transformer encoder layers')
parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
help='Frequency at which we insert MoE Transformer decoder layers')
parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
help='Number of experts in each MoE Layer')
parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
help="Use FP32 computations in MoE top2 gating function")
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
help="policy for second expert, options: all/sampling/random")
parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
help="MoE Expert FFN dimension")
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
help="Use top1 gate instead of top2")
parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
help="Fraction of tokens as capacity during validation" + \
"if set to negative, use same as training. range: (0.0, 1.0].")
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
help="Don't route padding tokens to any expert")
parser.add_argument('--use-xmoe', default=False, action='store_true',
help="Enable X-Moe")
parser.add_argument('--freeze-moe', default=False, action='store_true',
help="Freeze MoE Params")
parser.add_argument('--deepnorm', default=False, action='store_true',
help="Enable DeepNorm")
parser.add_argument('--subln', default=False, action='store_true',
help="Enable SubLN")
parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
# args for pseudo-MoE layers
parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
help="FFN embed dim of alternate pseudo-MoE blocks")
parser.add_argument('--rel-pos-buckets', type=int, default=0,
help='')
parser.add_argument('--max-rel-pos', type=int, default=0,
help='')
# fmt: on
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
if getattr(args, "max_source_positions", None) is None:
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
args.ddp_rank = distributed_utils.get_data_parallel_rank()
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = cls.build_embedding(
args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
)
decoder_embed_tokens = cls.build_embedding(
args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
)
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
encoder_embed_positions = (
PositionalEmbedding(
args.max_source_positions,
args.encoder_embed_dim,
src_dict.pad(),
learned=args.encoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
decoder_embed_positions = (
PositionalEmbedding(
args.max_target_positions,
args.decoder_embed_dim,
tgt_dict.pad(),
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
decoder_embed_tokens.weight.shape[1],
decoder_embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = decoder_embed_tokens.weight
else:
output_projection = torch.nn.Linear(
decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
)
encoder = cls.build_encoder(
args,
encoder_embed_tokens,
encoder_embed_positions,
src_dict,
)
decoder = cls.build_decoder(
args,
decoder_embed_tokens,
decoder_embed_positions,
output_projection,
tgt_dict,
)
if not args.share_all_embeddings:
min_params_to_wrap = getattr(
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
)
# fsdp_wrap is a no-op when --ddp-backend != fully_sharded
encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
return cls(args, encoder, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
@classmethod
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
config = EncoderConfig()
config.override(args)
return MTEncoder(
config,
embed_tokens,
embed_positions,
is_encoder_decoder=True,
dictionary=dictionary,
)
@classmethod
def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
config = DecoderConfig()
config.override(args)
return MTDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=True,
dictionary=dictionary,
)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
return_all_hiddens: bool = False,
features_only: bool = False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
return_all_hiddens=return_all_hiddens,
)
return decoder_out
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class MTEncoder(Encoder, FairseqEncoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask,
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
}
def max_positions(self):
return self.embed_positions.max_positions
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.no_cross_attention = getattr(args, "no_cross_attention", False)
args.cross_self_attention = getattr(args, "cross_self_attention", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
args.is_moe = getattr(args, "is_moe", False)
args.selected_expert_count = getattr(args, "selected_expert_count", 2)

@ -0,0 +1,32 @@
import argparse
import importlib
import os
# register dataclass
TASK_DATACLASS_REGISTRY = {}
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
path = os.path.join(tasks_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
task_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("tasks." + task_name)
# expose `task_parser` for sphinx
if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group("Task name")
# fmt: off
group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``')
# fmt: on
group_args = parser.add_argument_group("Additional command-line arguments")
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser

@ -0,0 +1,78 @@
import math
import re
import sys
import time
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
class BaseBatchGen(CheckpointableIterator):
"""
This is a base class for batch generators that use infinibatch
"""
def __init__(self):
self._iter = None
self.epoch = 1
self.next_epoch_idx = 1
self.sharded_checkpoint = True
self.should_close_after_finished = True
def _build_iter(self):
"""
Build infinibatch iterator and assign to self._iter
"""
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
return torch.tensor(x)
return utils.apply_to_sample(to_tensor, batch)
@property
def iterator(self):
if self._iter is None:
raise NotImplementedError("_build_iter() must called first")
return self._iter
def __iter__(self):
if self._iter is None:
raise NotImplementedError("_build_iter() must called first")
return self._iter
def __next__(self):
return next(self._iter)
def setstate(self, value):
self._iter.setstate(value)
def getstate(self):
return self._iter.getstate()
def close(self):
self._iter.close()
def __len__(self) -> int:
return 819200000
def next_epoch_itr(
self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True
):
return self
def end_of_epoch(self) -> bool:
return False
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
return self.getstate()
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
self.setstate(state_dict)
@property
def first_batch(self):
return "DUMMY"

@ -0,0 +1,308 @@
import glob
import os
import torch
import numpy as np
import time
import json
import random
import itertools
import copy
from infinibatch import iterators
from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator
class MLMLoader(BaseBatchGen):
def __init__(
self,
args,
dataset,
dictionary,
tokenizer,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
):
super().__init__()
self.args = args
self.data = dataset.data
self.data_dir = dataset.data_dir
self.shuffle = dataset.shuffle
self.dictionary = dictionary
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.max_sentences = max_sentences
self.max_positions = max_positions
self.tokens_per_sample = args.tokens_per_sample
self.sample_break_mode = args.sample_break_mode
self.ignore_invalid_inputs = ignore_invalid_inputs
self.required_batch_size_multiple = required_batch_size_multiple
self.seed = str(seed)
self.num_shards = num_shards
self.shard_id = shard_id
self.batch_read_ahead = args.batch_read_ahead
self._build_iter()
def _build_iter(self):
tokenized_lines = self._multilingual_tokenize()
self.padded_batches = self._batchify(tokenized_lines)
prefetch_batches = iterators.PrefetchIterator(
self.padded_batches,
buffer_size=10000,
buffer_in_main_process=True,
log_empty_buffer_warning=True and self.shard_id == 0,
)
prefetch_batches = iterators.MapIterator(
prefetch_batches, self._move_to_tensor
)
self._iter = prefetch_batches
def _multilingual_tokenize(self):
multilingual_iters = []
weights = []
for data in self.data:
multilingual_iters.append(
self._tokenize(data)
)
if 'weight' in data:
weights.append(float(data['weight']))
else:
weights.append(int(data['count']))
if len(multilingual_iters) == 1:
return multilingual_iters[0]
sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
return tokenized_lines
def _tokenize(self, data):
'''
data:
{
'source': list[Path],
'source_lang': str,
'count': int,
'weight': float,
'name': str,
}
'''
dataset = list(
zip(
data['source'],
itertools.repeat(data['source_lang']),
)
)
if self.shuffle:
chunk_files = \
iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
else:
chunk_files = \
iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
return tokenized_lines
def _batchify(self, lines):
if self.max_sentences is not None:
if self.batch_read_ahead > 0:
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
else:
def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator(
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
)
def collate(batch):
batch_size = len(batch)
mlm_source_max_length = max([len(x[0]) for x in batch])
mlm_target_max_length = max([len(x[1]) for x in batch])
s2s_source_max_length = max([len(x[2]) for x in batch])
s2s_target_max_length = max([len(x[3]) for x in batch])
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
ret_batch = {
'net_input': {
'src_tokens': mlm_source_ids.astype(np.int64),
},
'target': mlm_target_ids.astype(np.int64),
'nsentences': batch_size,
'ntokens': sum([len(x[0]) for x in batch]),
}
return ret_batch
padded_batches = iterators.MapIterator(
batches, collate
)
return padded_batches
def _prepare(self, _random, doc):
nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
def _mask_lm(self, _random, doc):
def mask_tokens():
return f"<mask>"
length = len(doc)
mask_tokens_num = int(length * self.args.mask_prob)
mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
possible_mask_positions = _random.sample(range(length), k=mask_tokens_num)
possible_mask_positions = sorted(possible_mask_positions)
nonmasked_tokens = copy.deepcopy(doc)
masked_tokens = [self.dictionary.pad() for _ in range(len(doc))]
for position in possible_mask_positions:
# masked_tokens.append(nonmasked_tokens[position])
masked_tokens[position] = nonmasked_tokens[position]
nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
return nonmasked_tokens, masked_tokens
def _span_corruption(self, _random, doc):
def mask_tokens(i):
return f"<mask_{i}>"
length = len(doc)
noise_tokens_num = int(length * self.args.mask_prob)
noise_tokens_num = min(max(noise_tokens_num, 1), length - 1)
noise_spans_num = int(noise_tokens_num / self.args.span_length)
noise_spans_num = max(noise_spans_num, 1)
nonnoise_tokens_num = length - noise_tokens_num
if noise_spans_num == 1:
noise_split_positions = [0, noise_tokens_num]
else:
possible_split_positions = list(range(1, noise_tokens_num))
_random.shuffle(possible_split_positions)
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
possible_insert_positions = list(range(nonnoise_tokens_num))
_random.shuffle(possible_insert_positions)
noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])
nonnoise_spans, noise_spans = [], []
last_end = 0
for i in range(noise_spans_num):
start_pos = noise_insert_positions[i] + noise_split_positions[i]
end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
mask_id = self.dictionary.indices[mask_tokens(i)]
if getattr(self.args, "remove_target_sentinel", False):
noise_spans.append(doc[start_pos:end_pos])
else:
noise_spans.append([mask_id] + doc[start_pos:end_pos])
if getattr(self.args, "remove_source_sentinel", False):
nonnoise_spans.extend(doc[last_end:start_pos])
else:
nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
last_end = end_pos
nonnoise_spans.extend(doc[last_end:])
noise_spans = sum(noise_spans, [])
return nonnoise_spans, noise_spans
def _read_from_files(self, source_file, source_lang):
# data = []
file_path = os.path.join(self.data_dir, source_file)
if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True)
return iter([]) # skip bad file
with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n')
doc = [self.dictionary.bos()]
for line in lines:
if line == "":
if self.sample_break_mode == 'complete_doc':
# data.append(doc)
yield doc
doc = [self.dictionary.bos()]
continue
tokenized_line = self.tokenizer.EncodeAsPieces(line)
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
if len(tokenized_id) > self.tokens_per_sample:
continue
if len(doc) + len(tokenized_id) > self.tokens_per_sample:
# data.append(doc)
yield doc
doc = [self.dictionary.bos()]
doc.extend(tokenized_id)
if len(doc) > 1 and len(doc) <= self.tokens_per_sample:
# data.append(doc)
yield doc
# return data

@ -0,0 +1,82 @@
import os
import gzip
import numpy as np
from random import Random
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
import collections
from infinibatch import iterators
def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
def _apply(x):
if isinstance(x, np.ndarray):
return f(x)
elif isinstance(x, collections.OrderedDict):
# OrderedDict has attributes that needs to be preserved
od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
od.__dict__ = x.__dict__
return od
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
return [_apply(x) for x in x]
elif isinstance(x, tuple):
return tuple(_apply(x) for x in x)
elif isinstance(x, set):
return {_apply(x) for x in x}
else:
return x
return _apply(sample)
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def __init__(self, iterable: Iterable):
self._input_iterable = iterable
self.setstate(None)
def getstate(self) -> Dict:
return {'num_items_yielded': self._num_items_yielded}
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
def __next__(self):
item = next(self._iterator)
self._num_items_yielded += 1
return item
def close(self):
pass
class WeightIterator(object):
def __init__(self, weights, seed):
self.weights = weights
self.seed = seed
self.control_index = list(range(len(weights)))
self.setstate(None)
def __iter__(self):
return self
def getstate(self):
return {"random_state": self._random_state}
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None # this will trigger the lazy initialization in self.__next__
def __next__(self):
if self._random is None:
self._random = Random(self.seed)
if self._random_state is not None:
self._random.setstate(self._random_state)
idx = self._random.choices(self.control_index, self.weights)[0]
self._random_state = self._random.getstate()
return idx
def close(self):
pass

@ -0,0 +1,207 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
import os
from argparse import Namespace
import json
from omegaconf import MISSING, II, OmegaConf
from typing import Any
import numpy as np
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
import sentencepiece as spm
logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class PretrainingConfig(FairseqDataclass):
data: str = field(
default=MISSING,
metadata={
"help": "colon separated path to data directories list, \
will be iterated upon during epochs in round-robin manner"
},
)
sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field(
default="complete",
metadata={
"help": 'If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
"of sentence, but may include multiple sentences per sample. "
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.'
},
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
mask_prob: float = field(
default=0.15,
metadata={"help": "probability of replacing a token with mask"},
)
leave_unmasked_prob: float = field(
default=0.1,
metadata={"help": "probability that a masked token is unmasked"},
)
random_token_prob: float = field(
default=0.1,
metadata={"help": "probability of replacing a token with a random token"},
)
freq_weighted_replacement: bool = field(
default=False,
metadata={"help": "sample random replacement words based on word frequencies"},
)
mask_whole_words: bool = field(
default=False,
metadata={"help": "mask whole words; you may also want to set --bpe"},
)
mask_multiple_length: int = field(
default=1,
metadata={"help": "repeat the mask indices multiple times"},
)
mask_stdev: float = field(
default=0.0,
metadata={"help": "stdev of the mask length"},
)
shorten_method: SHORTEN_METHOD_CHOICES = field(
default="none",
metadata={
"help": "if not none, shorten sequences that exceed --tokens-per-sample"
},
)
shorten_data_split_list: str = field(
default="",
metadata={
"help": "comma-separated list of dataset splits to apply shortening to, "
'e.g., "train,valid" (default: all dataset splits)'
},
)
seed: int = II("common.seed")
span_length: float = field(
default=3.0,
metadata={"help": "average span length for masking"},
)
remove_source_sentinel: bool = field(
default=False,
metadata={"help": "remove the source sentinel for the span corruption task"},
)
remove_target_sentinel: bool = field(
default=False,
metadata={"help": "remove the target sentinel for the span corruption task"},
)
batch_read_ahead: int = field(
default=100000,
metadata={"help": "batch read ahead size for infinibatch"},
)
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
spm_model: str = field(
default="",
metadata={
"help": "sentencepice model to tokenize the data"
},
)
dict_file: str = field(
default="",
metadata={
"help": ""
},
)
@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):
def __init__(self, cfg, dictionary, tokenizer):
super().__init__(cfg)
self.cfg = cfg
self.dictionary = dictionary
self.tokenizer = tokenizer
self.seed = cfg.seed
self.mask_idx = dictionary.index("<mask>")
@classmethod
def setup_task(cls, cfg, **kwargs):
paths = utils.split_paths(cfg.data)
assert len(paths) > 0
if cfg.dict_file != "":
dictionary = Dictionary.load(cfg.dict_file)
else:
dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
# add mask token
dictionary.add_symbol("<mask>")
for i in range(100):
dictionary.add_symbol(f"<mask_{i}>")
dictionary.pad_to_multiple_(cfg.required_batch_size_multiple)
logger.info("dictionary: {} types".format(len(dictionary)))
# tokenizer = SentencepieceBPE(Namespace(sentencepiece_model=cfg.spm_model))
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(cfg.spm_model)
return cls(cfg, dictionary, tokenizer)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = {
'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
'data_dir': self.cfg.data,
'shuffle': True if split == 'train' else False,
}
self.datasets[split] = Namespace(**self.datasets[split])
def dataset(self, split):
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
return self.datasets[split]
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
):
return MLMLoader(
self.cfg,
dataset,
self.dictionary,
self.tokenizer,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary

@ -0,0 +1,8 @@
import models
import tasks
from fairseq_cli.train import cli_main
if __name__ == "__main__":
cli_main()

@ -0,0 +1,75 @@
import torch
import warnings
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist
import math
@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def grad_exists(p):
return p is not None and getattr(p, "grad", None) is not None
if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
params = list(filter(grad_exists, params))
grads, expert_grads, base_expert_grads, sharded_grads = [], [], [], []
denom = math.sqrt(max(dist.get_global_world_size(), moe_expert_count))
for p in params:
if hasattr(p, "expert"):
expert_grads.append(p.grad.detach() / denom)
elif hasattr(p, "base_expert"):
base_expert_grads.append(p.grad.detach())
elif hasattr(p, "_is_sharded"):
sharded_grads.append(p.grad.detach())
else:
grads.append(p.grad.detach())
if len(grads) == 0:
if len(params) > 0:
total_norm = params[0].new_tensor(0.0)
else:
total_norm = torch.tensor(0.0)
elif len(grads) == 1:
total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
else:
if multi_tensor_l2norm_available:
total_norm = multi_tensor_total_norm(grads)
else:
if torch.cuda.is_available():
warnings.warn(
"amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
"you may get better performance by installing NVIDIA's apex library"
)
device = torch.cuda.current_device()
elif grads[0].device.type == "xla":
device = grads[0].device
else:
device = torch.device("cpu")
total_norm = torch.norm(
torch.stack(
[torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads]
)
)
# calculate split_norm and all_reduce with other workers
norms = [total_norm]
for split_grads in [expert_grads, sharded_grads]:
if len(split_grads) == 0:
continue
split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
if dist.is_initialized():
split_norm.pow_(2)
dist.all_reduce(split_norm)
split_norm.sqrt_()
norms.append(split_norm)
if len(norms) > 1:
total_norm = torch.norm(torch.stack(norms))
if aggregate_norm_fn is not None:
total_norm = aggregate_norm_fn(total_norm)
if max_norm > 0:
max_norm = float(max_norm)
clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
for g in grads + expert_grads + sharded_grads + base_expert_grads:
g.mul_(clip_coef)
return total_norm

@ -0,0 +1,25 @@
from io import open
from setuptools import find_packages, setup
setup(
name="torchscale",
version="0.1.1",
author="TorchScale Team",
author_email="Shuming.Ma@microsoft.com",
description="Transformers at any scale",
long_description=open("README.md", "r", encoding='utf-8').read(),
long_description_content_type="text/markdown",
keywords="Transformers at any scale",
license="MIT",
url="https://github.com/msranlp/torchscale",
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=['apex',
'torch>=1.8',
'fairscale==0.4.0',
'timm==0.4.12'],
python_requires='>=3.8.0',
classifiers=[
'Programming Language :: Python :: 3',
],
)

@ -0,0 +1,33 @@
import pytest
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
import torch
testcases = [
{},
{"vocab_size": 64000},
{"activation_fn": "relu"},
{"drop_path_rate": 0.1},
{"decoder_normalize_before": False},
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
{"deepnorm": True, "subln": False, "decoder_normalize_before": False},
{"bert_init": True},
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = DecoderConfig(**args)
model = Decoder(config)
prev_output_tokens = torch.ones(2, 10)
token_embeddings = torch.rand(2, 10, config.decoder_embed_dim)
model(
prev_output_tokens=prev_output_tokens,
token_embeddings=token_embeddings,
features_only=True,
)

@ -0,0 +1,28 @@
import pytest
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
import torch
testcases = [
{},
{"vocab_size": 64000},
{"activation_fn": "relu"},
{"drop_path_rate": 0.1},
{"encoder_normalize_before": False},
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
{"deepnorm": True, "subln": False, "encoder_normalize_before": False},
{"bert_init": True},
{"multiway": True},
{"share_encoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
config = EncoderConfig(**args)
model = Encoder(config)
token_embeddings = torch.rand(2, 10, config.encoder_embed_dim)
model(src_tokens=None, token_embeddings=token_embeddings)

@ -0,0 +1,43 @@
import pytest
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
import torch
testcases = [
{},
{"vocab_size": 64000},
{"activation_fn": "relu"},
{"drop_path_rate": 0.1},
{"encoder_normalize_before": False, "decoder_normalize_before": False},
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
{"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
{"bert_init": True},
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"share_all_embeddings": True},
{"checkpoint_activations": True},
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = EncoderDecoderConfig(**args)
model = EncoderDecoder(
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
)
src_tokens = torch.ones(2, 20).long()
prev_output_tokens = torch.ones(2, 10).long()
model(
src_tokens=src_tokens,
prev_output_tokens=prev_output_tokens,
features_only=True,
)

@ -0,0 +1,160 @@
class EncoderConfig(object):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.use_xmoe = kwargs.pop("use_xmoe", True)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Vision
self.img_size = kwargs.pop("img_size", 224)
self.patch_size = kwargs.pop("patch_size", 16)
self.in_chans = kwargs.pop("in_chans", 3)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
if self.deepnorm:
self.encoder_normalize_before = False
if self.subln:
self.encoder_normalize_before = True
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
class DecoderConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.use_xmoe = kwargs.pop("use_xmoe", True)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
if self.deepnorm:
self.decoder_normalize_before = False
if self.subln:
self.decoder_normalize_before = True
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
class EncoderDecoderConfig(object):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.use_xmoe = kwargs.pop("use_xmoe", True)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
if self.deepnorm:
self.encoder_normalize_before = False
self.decoder_normalize_before = False
if self.subln:
self.encoder_normalize_before = True
self.decoder_normalize_before = True
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)

@ -0,0 +1,447 @@
import math
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias
class DecoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False,
):
super().__init__()
self.args = args
self.embed_dim = args.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.normalize_before = args.decoder_normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
if not is_encoder_decoder:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
if not self.is_moe_layer:
self.ffn = self.build_ffn(
self.embed_dim,
self.args,
)
else:
if args.moe_top1_expert:
gate = Top1Gate(
self.embed_dim,
args.moe_expert_count,
use_fp32=args.moe_gating_use_fp32,
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
else:
gate = Top2Gate(
self.embed_dim,
args.moe_expert_count,
args.moe_gating_use_fp32,
args.moe_second_expert_policy,
args.moe_normalize_gate_prob_before_dropping,
args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = LayerNorm(self.embed_dim)
if args.deepnorm:
if is_encoder_decoder:
self.alpha = math.pow(3.0 * args.decoder_layers, 0.25)
else:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
else:
self.alpha = 1.0
if args.subln:
self.ffn_layernorm = LayerNorm(self.ffn_dim)
else:
self.ffn_layernorm = None
def build_ffn(self, embed_dim, args):
return FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln,
)
def build_self_attention(self, embed_dim, args):
return MultiheadAttention(
args,
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
encoder_decoder_attention=False,
subln=args.subln,
)
def build_encoder_attention(self, embed_dim, args):
return MultiheadAttention(
args,
embed_dim,
args.decoder_attention_heads,
dropout=args.attention_dropout,
self_attention=False,
encoder_decoder_attention=True,
subln=args.subln,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
encoder_out=None,
encoder_padding_mask=None,
incremental_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
self_attn_rel_pos=None,
cross_attn_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
attn_mask=self_attn_mask,
rel_pos=self_attn_rel_pos,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
if self.encoder_attn is not None and encoder_out is not None:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=None,
rel_pos=cross_attn_rel_pos,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
if not self.is_moe_layer:
x = self.ffn(x)
l_aux = None
else:
x = x.transpose(0, 1)
x, l_aux = self.moe_layer(x)
x = x.transpose(0, 1)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, attn, None, l_aux
class Decoder(nn.Module):
def __init__(
self,
args,
embed_tokens=None,
embed_positions=None,
output_projection=None,
is_encoder_decoder=False,
**kwargs
):
super().__init__(**kwargs)
self.args = args
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions
if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
moe_freq = args.moe_freq
for i in range(args.decoder_layers):
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
self.layers.append(
self.build_decoder_layer(
args,
depth=i,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder,
)
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
self.output_projection = output_projection
self.self_attn_relative_position = None
self.cross_attn_relative_position = None
if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
self.self_attn_relative_position = RelativePositionBias(
num_buckets=args.rel_pos_buckets,
max_distance=args.max_rel_pos,
n_heads=args.decoder_attention_heads,
)
if is_encoder_decoder:
self.cross_attn_relative_position = RelativePositionBias(
num_buckets=args.rel_pos_buckets,
max_distance=args.max_rel_pos,
n_heads=args.decoder_attention_heads,
)
if args.bert_init:
self.apply(init_bert_params)
if args.deepnorm:
if is_encoder_decoder:
init_scale = math.pow(12.0 * args.decoder_layers, 0.25)
else:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
p.data.div_(init_scale)
if args.subln:
if is_encoder_decoder:
init_scale = math.sqrt(math.log(args.decoder_layers * 3))
else:
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
for name, p in self.named_parameters():
if 'encoder_attn' in name:
continue
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
p.data.mul_(init_scale)
def build_output_projection(
self,
args,
):
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
)
return output_projection
def build_decoder_layer(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
):
layer = DecoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder,
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
if args.fsdp:
layer = wrap(layer)
return layer
def forward_embedding(
self,
tokens,
token_embedding=None,
incremental_state=None,
):
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(tokens, incremental_state=incremental_state)
if incremental_state is not None:
tokens = tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
if token_embedding is None:
token_embedding = self.embed_tokens(tokens)
x = embed = self.embed_scale * token_embedding
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def forward(
self,
prev_output_tokens,
self_attn_padding_mask=None,
encoder_out=None,
incremental_state=None,
features_only=False,
return_all_hiddens=False,
token_embeddings=None,
**kwargs
):
# embed tokens and positions
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
x = x.transpose(0, 1)
# relative postion
self_attn_rel_pos_bias = None
slen = prev_output_tokens.size(1)
if self.self_attn_relative_position is not None:
self_attn_rel_pos_bias = self.self_attn_relative_position(
batch_size=x.size(1),
qlen=slen,
klen=slen
)
if incremental_state is not None:
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
cross_attn_rel_pos_bias = None
if self.cross_attn_relative_position is not None:
cross_attn_rel_pos_bias = self.cross_attn_relative_position(
batch_size=x.size(1),
qlen=slen,
klen=encoder_out["encoder_out"].size(0),
)
if incremental_state is not None:
cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[:, -1:, :]
# decoder layers
inner_states = [x]
if encoder_out is None:
l_aux = []
else:
l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
for idx, layer in enumerate(self.layers):
if incremental_state is None:
self_attn_mask = torch.triu(
torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
)
else:
self_attn_mask = None
if idx not in incremental_state:
incremental_state[idx] = {}
x, layer_attn, _, l_aux_i = layer(
x,
encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
incremental_state[idx] if incremental_state is not None else None,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
self_attn_rel_pos=self_attn_rel_pos_bias,
cross_attn_rel_pos=cross_attn_rel_pos_bias,
)
l_aux.append(l_aux_i)
inner_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
x = x.transpose(0, 1)
if not features_only:
x = self.output_layer(x)
return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}
def output_layer(self, features):
return self.output_projection(features)

@ -0,0 +1,367 @@
import math
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias
class EncoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
):
super().__init__()
self.args = args
self.embed_dim = args.encoder_embed_dim
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.normalize_before = args.encoder_normalize_before
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.encoder_ffn_embed_dim
if not self.is_moe_layer:
self.ffn = MultiwayWrapper(
args,
self.build_ffn(
self.embed_dim,
self.args,
)
)
else:
assert not self.args.multiway
if args.moe_top1_expert:
gate = Top1Gate(
self.embed_dim,
args.moe_expert_count,
use_fp32=args.moe_gating_use_fp32,
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
else:
gate = Top2Gate(
self.embed_dim,
args.moe_expert_count,
args.moe_gating_use_fp32,
args.moe_second_expert_policy,
args.moe_normalize_gate_prob_before_dropping,
args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
if args.deepnorm:
if is_encoder_decoder:
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
else:
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
else:
self.alpha = 1.0
def build_ffn(self, embed_dim, args):
return FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln,
)
def build_self_attention(self, embed_dim, args):
return MultiheadAttention(
args,
embed_dim,
args.encoder_attention_heads,
dropout=args.attention_dropout,
self_attention=True,
encoder_decoder_attention=False,
subln=args.subln,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
encoder_padding_mask,
attn_mask=None,
rel_pos=None
):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=encoder_padding_mask,
attn_mask=attn_mask,
rel_pos=rel_pos,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
if not self.is_moe_layer:
x = self.ffn(x)
l_aux = None
else:
x = x.transpose(0, 1)
x, l_aux = self.moe_layer(x)
x = x.transpose(0, 1)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, l_aux
class Encoder(nn.Module):
def __init__(
self,
args,
embed_tokens=None,
embed_positions=None,
output_projection=None,
is_encoder_decoder=False,
**kwargs
):
self.args = args
super().__init__(**kwargs)
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
embed_dim = args.encoder_embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions
if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
moe_freq = args.moe_freq
for i in range(args.encoder_layers):
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
self.layers.append(
self.build_encoder_layer(
args,
depth=i,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder
)
)
self.num_layers = len(self.layers)
if args.encoder_normalize_before:
self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
else:
self.layer_norm = None
if args.rel_pos_buckets > 0 and args.max_rel_pos > 0:
self.relative_position = RelativePositionBias(
num_buckets=args.rel_pos_buckets,
max_distance=args.max_rel_pos,
n_heads=args.encoder_attention_heads,
)
else:
self.relative_position = None
if args.bert_init:
self.apply(init_bert_params)
if args.deepnorm:
if is_encoder_decoder:
init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
else:
init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
p.data.div_(init_scale)
if args.subln:
if is_encoder_decoder:
init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
else:
init_scale = math.sqrt(math.log(args.encoder_layers * 2))
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
p.data.mul_(init_scale)
def build_output_projection(
self,
args,
):
if args.share_encoder_input_output_embed:
assert args.encoder_embedding_type == 'language'
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.encoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
)
return output_projection
def build_encoder_layer(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
):
layer = EncoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
if args.fsdp:
layer = wrap(layer)
return layer
def forward_embedding(
self,
src_tokens,
token_embedding=None,
):
if token_embedding is None:
token_embedding = self.embed_tokens(src_tokens)
x = embed = self.embed_scale * token_embedding
if self.embed_positions is not None:
if src_tokens is not None:
x = embed + self.embed_positions(src_tokens)
else:
x = embed + self.embed_positions(x)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def forward(
self,
src_tokens,
encoder_padding_mask=None,
return_all_hiddens=False,
token_embeddings=None,
multiway_split_position=None,
features_only=False,
**kwargs
):
assert src_tokens is not None or token_embeddings is not None
if encoder_padding_mask is None:
if src_tokens is not None:
encoder_padding_mask = torch.zeros_like(
src_tokens,
device=src_tokens.device
).bool()
else:
encoder_padding_mask = torch.zeros(
[token_embeddings.size(0), token_embeddings.size(1)],
device=token_embeddings.device
).bool()
if multiway_split_position is not None:
assert self.args.multiway
self.apply(set_split_position(multiway_split_position))
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
x = x.transpose(0, 1)
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
rel_pos_bias = None
if self.relative_position is not None:
rel_pos_bias = self.relative_position(
batch_size=x.size(1),
qlen=x.size(0),
klen=x.size(0)
)
l_aux = []
for layer in self.layers:
x, l_aux_i = layer(
x, encoder_padding_mask=encoder_padding_mask,
rel_pos=rel_pos_bias
)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
l_aux.append(l_aux_i)
if self.layer_norm is not None:
x = self.layer_norm(x)
if not features_only and self.output_projection is not None:
x = self.output_projection(x)
return {
"encoder_out": x,
"encoder_embedding": encoder_embedding,
"encoder_padding_mask": encoder_padding_mask,
"encoder_states": encoder_states,
"l_aux": l_aux,
}

@ -0,0 +1,61 @@
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.decoder import Decoder
class EncoderDecoder(nn.Module):
def __init__(
self,
args,
encoder_embed_tokens=None,
encoder_embed_positions=None,
decoder_embed_tokens=None,
decoder_embed_positions=None,
output_projection=None,
**kwargs
):
super().__init__()
self.args = args
if args.share_all_embeddings:
args.share_decoder_input_output_embed = True
self.encoder = Encoder(
args,
encoder_embed_tokens,
encoder_embed_positions,
is_encoder_decoder=True,
**kwargs
)
if args.share_all_embeddings and decoder_embed_tokens is None:
decoder_embed_tokens = self.encoder.embed_tokens
self.decoder = Decoder(
args,
decoder_embed_tokens,
decoder_embed_positions,
output_projection,
is_encoder_decoder=True,
**kwargs
)
def forward(
self,
src_tokens,
prev_output_tokens,
return_all_hiddens=False,
features_only=False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
return_all_hiddens=return_all_hiddens,
)
return decoder_out

@ -0,0 +1,30 @@
import torch.nn as nn
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.multiway_network import MultiwayNetwork
def init_bert_params(module):
def normal_(data):
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, MultiheadAttention):
if isinstance(module.q_proj, MultiwayNetwork):
normal_(module.q_proj.A.weight.data)
normal_(module.q_proj.B.weight.data)
normal_(module.k_proj.A.weight.data)
normal_(module.k_proj.B.weight.data)
normal_(module.v_proj.A.weight.data)
normal_(module.v_proj.B.weight.data)
else:
normal_(module.q_proj.weight.data)
normal_(module.k_proj.weight.data)
normal_(module.v_proj.weight.data)

@ -0,0 +1,16 @@
from timm.models.layers import drop_path
import torch.nn as nn
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self):
return 'p={}'.format(self.drop_prob)

@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class VisionLanguageEmbedding(nn.Module):
def __init__(
self,
text_embed,
vision_embed
):
super().__init__()
self.text_embed = text_embed
self.vision_embed = vision_embed
def forward(
self,
textual_tokens,
visual_tokens,
**kwargs
):
if textual_tokens is None:
return self.vision_embed(visual_tokens)
if visual_tokens is None:
return self.text_embed(textual_tokens)
x1 = self.vision_embed(visual_tokens)
x2 = self.text_embed(textual_tokens)
return torch.cat([x1, x2], dim=1)
class VisionEmbedding(nn.Module):
""" Image to Patch Embedding
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
contain_mask_token=False,
prepend_cls_token=False
):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if contain_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
else:
self.mask_token = None
if prepend_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
else:
self.cls_token = None
def forward(
self,
x,
masked_position=None,
**kwargs
):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
batch_size, seq_len, _ = x.size()
if masked_position is not None:
assert self.mask_token is not None
mask_token = self.mask_token.expand(batch_size, seq_len, -1)
w = masked_position.unsqueeze(-1).type_as(mask_token)
x = x * (1 - w) + mask_token * w
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
return x
class TextEmbedding(nn.Embedding):
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
self._fill_padding_idx_with_zero()
class PositionalEmbedding(nn.Embedding):
def forward(
self,
x,
positions=None,
**kwargs,
):
if positions is None:
# being consistent with Fairseq, which starts from 2.
positions = torch.arange(2, x.size(1)+2, device=x.device).long().unsqueeze(0)
return F.embedding(
positions,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)

@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
class set_torch_seed(object):
def __init__(self, seed):
assert isinstance(seed, int)
self.rng_state = self.get_rng_state()
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def get_rng_state(self):
state = {"torch_rng_state": torch.get_rng_state()}
if torch.cuda.is_available():
state["cuda_rng_state"] = torch.cuda.get_rng_state()
return state
def set_rng_state(self, state):
torch.set_rng_state(state["torch_rng_state"])
if torch.cuda.is_available():
torch.cuda.set_rng_state(state["cuda_rng_state"])
def __enter__(self):
return self
def __exit__(self, *exc):
self.set_rng_state(self.rng_state)
def make_experts(args, embed_dim, expert_ffn_dim):
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
expert_list = []
ddp_rank = args.ddp_rank
start_seed = torch.randint(1000000, (1,)).item()
# at least as many experts than gpus
if args.moe_expert_count >= world_size:
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
local_moe_expert_count = args.moe_expert_count // world_size
for i in range(local_moe_expert_count):
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
expert_list.append(
FeedForwardNetwork(
embed_dim,
expert_ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
)
)
else:
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
expert_list.append(
FeedForwardNetwork(
embed_dim,
expert_ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
)
)
experts = nn.ModuleList(expert_list)
return experts
def get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return lambda x: F.gelu(x.float()).type_as(x)
else:
raise NotImplementedError
class FeedForwardNetwork(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
subln=False
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
if self.ffn_layernorm is not None:
self.ffn_layernorm.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
x = self.fc1(x)
x = self.activation_fn(x)
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x

@ -0,0 +1,117 @@
import math
import torch
from torch import nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from .multiway_network import MultiwayWrapper
class MultiheadAttention(nn.Module):
def __init__(
self,
args,
embed_dim,
num_heads,
dropout=0.0,
self_attention=False,
encoder_decoder_attention=False,
subln=False,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
assert self.self_attention ^ self.encoder_decoder_attention
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
def reset_parameters(self):
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj.weight)
nn.init.constant_(self.out_proj.bias, 0.0)
def forward(
self,
query,
key,
value,
incremental_state=None,
key_padding_mask=None,
attn_mask=None,
rel_pos=None,
):
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
src_len, key_bsz, _ = key.size()
assert key_bsz == bsz, f"{query.size(), key.size()}"
assert value is not None
assert src_len, bsz == value.shape[:2]
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
if incremental_state is not None:
if "prev_key" in incremental_state:
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
k = torch.cat([prev_key, k], dim=1)
v = torch.cat([prev_value, v], dim=1)
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
if attn_mask is not None:
attn_weights = torch.nan_to_num(attn_weights)
attn_mask = attn_mask.unsqueeze(0)
attn_weights += attn_mask
if key_padding_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if rel_pos is not None:
rel_pos = rel_pos.view(attn_weights.size())
attn_weights = attn_weights + rel_pos
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
attn = torch.bmm(attn_probs, v)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
if self.inner_attn_ln is not None:
attn = self.inner_attn_ln(attn)
attn = self.out_proj(attn)
attn_weights = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
return attn, attn_weights

@ -0,0 +1,39 @@
import copy
import torch
import torch.nn as nn
def MultiwayWrapper(args, module, dim=0):
if args.multiway:
return MultiwayNetwork(module, dim=dim)
return module
def set_split_position(position):
def apply_fn(module):
if hasattr(module, 'split_position'):
module.split_position = position
return apply_fn
class MultiwayNetwork(nn.Module):
def __init__(self, module, dim=0):
super().__init__()
self.dim = dim
self.A = module
self.B = copy.deepcopy(module)
self.B.reset_parameters()
self.split_position = -1
def forward(self, x, **kwargs):
if self.split_position == -1:
return self.A(x, **kwargs)
if self.split_position == 0:
return self.B(x, **kwargs)
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim)
# x1, x2 = x[:self.split_position], x[self.split_position:]
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
return torch.cat([y1, y2], dim=self.dim)

@ -0,0 +1,79 @@
import math
import torch
import torch.nn as nn
class RelativePositionBias(nn.Module):
def __init__(
self,
bidirectional=True,
num_buckets=32,
max_distance=128,
n_heads=12
):
super().__init__()
self.bidirectional = bidirectional
self.num_buckets = num_buckets
self.max_distance = max_distance
self.n_heads = n_heads
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
@staticmethod
def _relative_position_bucket(
relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128
):
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).to(torch.long) * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
def compute_bias(
self,
qlen,
klen,
step=None
):
step = 0 if step is None else step
context_position = torch.arange(step, step + qlen, dtype=torch.long,
device=self.relative_attention_bias.weight.device)[:, None]
memory_position = torch.arange(klen, dtype=torch.long,
device=self.relative_attention_bias.weight.device)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
return values
def forward(
self,
batch_size,
qlen,
klen,
step=None
):
# shape (batch * num_heads, qlen, klen)
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)

@ -0,0 +1,310 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# NOTE: This is a mirror of the code in
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
import logging
import time
from typing import Any, Tuple, cast
import torch
import torch.distributed as dist
from torch import Tensor
from torch.nn import Module, ModuleList
try:
from fairseq.modules.moe import MOELayer
has_fairseq = True
Base = MOELayer
except ModuleNotFoundError:
Base = Module
has_fairseq = False
try:
# To enable Tutel MoE optimizations:
# python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
from tutel import moe as tutel_moe
has_tutel, fused_cumsum_sub_one = True, tutel_moe.fast_cumsum_sub_one
except ModuleNotFoundError:
has_tutel, fused_cumsum_sub_one = False, lambda mask: torch.cumsum(mask, dim=0) - 1
logger = logging.getLogger(__name__)
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
if torch.distributed.is_initialized():
dist.all_to_all_single(output, input, group=group)
else:
assert group is None
output = input
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)]
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
class MOELayer(Base):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
::
gate = Top2Gate(model_dim, num_experts)
moe = MOELayer(gate, expert)
output = moe(input)
l_aux = moe.l_aux
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
gate (torch.nn.Module):
gate network
expert (torch.nn.Module):
expert network
"""
def __init__(
self,
gate,
experts,
args
):
if has_fairseq:
super(Base, self).__init__()
else:
super().__init__()
self.gate = gate
if type(experts) == ModuleList:
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.expert_group = get_moe_group(args.moe_expert_count)
self.all2all_group = get_all2all_group(args.moe_expert_count)
self.world_size = dist.get_world_size(group=self.expert_group)
self.all2all_size = dist.get_world_size(group=self.all2all_group)
for p in experts.parameters():
p.expert = True # type: ignore
self.num_local_experts = len(self.experts)
self.args = args
self.in_generation = False
self.a2a_cuda_event_intervals = []
self.a2a_cpu_time_ms = 0.0
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported"
input = input[0]
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
if input_padding_mask is not None:
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
assert input_padding_mask.shape[0] == input.shape[0]
assert input_padding_mask.shape[1] == input.shape[1]
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
# Implement Algorithm 2 from GShard paper.
d_model = input.shape[2]
# Pad to expected batch size
input_shape = list(input.shape)
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
# This indicates that --batch-size or --max-sentences is not specified
if expected_bsz is None:
expected_bsz = 0
# Note: Padding is not necessary at generation time at present
# because all DDP workers process the same batch. Also, batch size at generation time
# can be different from that present in the checkpoint state
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
padded_input = torch.zeros(
(expected_bsz, input_shape[1], input_shape[2]),
dtype=input.dtype, layout=input.layout, device=input.device)
padded_input[:input_shape[0], :, :] = input
input = padded_input
padded_input_padding_mask = torch.ones(
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
)
if input_padding_mask is not None:
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
else:
padded_input_padding_mask[:input_shape[0], :] = False
input_padding_mask = padded_input_padding_mask
# Reshape into S tokens by dropping sequence dimension.
reshaped_input = input.reshape(-1, d_model)
reshaped_input_shape = reshaped_input.shape
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
# Pro of --max-tokens: more flexible for MT variable sequence lengths
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
if expected_bsz == 0:
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
expected_dim = int(expected_dim.item())
padded_input = torch.zeros(
(expected_dim, reshaped_input_shape[1]),
dtype=input.dtype, layout=input.layout, device=input.device)
padded_input[:reshaped_input_shape[0], :] = reshaped_input
reshaped_input = padded_input
padded_input_padding_mask = torch.ones(
(expected_dim,), dtype=torch.bool, device=padded_input.device
)
if reshaped_input_padding_mask is not None:
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
else:
padded_input_padding_mask[:reshaped_input_shape[0]] = False
reshaped_input_padding_mask = padded_input_padding_mask
if has_tutel:
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
S, M = reshaped_input.size(0), reshaped_input.size(1)
if not hasattr(self, '_tutel_dispatcher'):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else:
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S
E, C, S = dispatch_mask.size()
M = reshaped_input.size(1)
assert reshaped_input.size() == (S, M)
# einsum("sec,sm->ecm")
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M
if self.all2all_size > 1:
dispatched_input = self.all_to_all_wrapper(dispatched_input)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
expert_outputs += [expert(chunk)]
expert_output = torch.cat(expert_outputs, dim=1)
if self.all2all_size > 1:
expert_output = self.all_to_all_wrapper(expert_output)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)
if has_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
else:
# einsum("sec,ecm->sm")
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
combined_output = combined_output[:reshaped_input_shape[0], :]
combined_output = combined_output.reshape(input.shape)
combined_output = combined_output[:input_shape[0], :, :]
self.record_all_to_all_stats()
return combined_output, l_aux
def prepare_for_inference_(self):
self.in_generation = True
def all_to_all_wrapper(self, input: Tensor):
dummy_a2a = getattr(self.args, 'dummy_a2a', False)
if dummy_a2a:
input = input.contiguous()
output = input.detach().clone()
return input
# always record times, since it is not a lot of overhead
# if we do not log it we simply clear it off in record_all_to_all_stats
cuda_start = torch.cuda.Event(enable_timing=True)
cuda_end = torch.cuda.Event(enable_timing=True)
cpu_start = time.time() * 1000
cuda_start.record()
output = _AllToAll.apply(self.all2all_group, input)
cuda_end.record()
cpu_end = time.time() * 1000
self.a2a_cpu_time_ms += (cpu_end - cpu_start)
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
return output
def record_all_to_all_stats(self):
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
if record_a2a_perf_stats:
torch.cuda.synchronize()
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
a2a_cuda_time_ms = 0.0
for ev_start, ev_end in self.a2a_cuda_event_intervals:
a2a_cuda_time_ms += ev_start.elapsed_time(ev_end)
self.metadata["all_to_all_cuda_time_ms"] = a2a_cuda_time_ms
# reset stats
self.a2a_cpu_time_ms = 0.0
self.a2a_cuda_event_intervals = []

@ -0,0 +1,459 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
# Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
# NOTE: This is a mirror of the code in
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
from typing import Callable, Dict, Tuple, Optional
import math
import torch
from torch import Tensor
import torch.nn.functional as F
from .moe_layer import has_tutel, fused_cumsum_sub_one
# use a fixed temperature to compute balance loss
TEMPERATURE_FOR_L_UAX = 0.07
# maximum capacity of 1 expert as a fraction of number of tokens in the batch
# Note: setting this to 1.0 causes inference to significantly slow down
EVAL_CAPACITY_TOKEN_FRACTION = 0.25
# logging
SAMPLE_FRACTION = 0.2
def top1gating(
logits: torch.Tensor,
input_mask: Optional[torch.Tensor] = None,
use_fp32=False,
capacity_factor=1.0,
eval_mode=False,
moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
use_xmoe=False,
gate_obj=None,
) -> Tuple[Tensor, Tensor, Tensor, Dict]:
"""Implements Top2Gating on logits."""
metadata = {}
if use_fp32:
orig_dtype = logits.dtype
logits = logits.float()
gates = F.softmax(logits, dim=1)
metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
# gates has shape of SE
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
else:
# capacity = capacity_factor * S/E
capacity = int(capacity_factor * math.ceil(num_tokens / num_experts))
# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
# for logging (percent of tokens routed to each expert)
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
gates1_s = (gates * mask1).sum(dim=1)
# Compute locations in capacity buffer
locations1 = fused_cumsum_sub_one(mask1)
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.to(gates.dtype), dim=0)
l_aux = torch.mean(me * ce)
l_aux = l_aux * num_experts * num_experts
if has_tutel:
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]
# Remove locations outside capacity from mask
mask1 = mask1 * torch.lt(locations1, capacity)
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
# Calculate combine_weights and dispatch_mask
gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
# locations1_sc = num_tokens * capacity
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
combine1_sec = torch.bmm(
# einsum("se,sc->sec")
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
)
dispatch_mask = combine1_sec.bool()
if use_fp32:
return l_aux, combine1_sec.to(orig_dtype), dispatch_mask, metadata
else:
return l_aux, combine1_sec, dispatch_mask, metadata
class Top1Gate(torch.nn.Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
gate = Top2Gate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
model_dim (int):
size of model embedding dimension
num_experts (ints):
number of experts in model
"""
wg: torch.nn.Linear
def __init__(
self,
model_dim: int,
num_experts: int,
use_fp32=False,
input_noise_type=None,
capacity_factor=1.0,
moe_eval_capacity_token_fraction=EVAL_CAPACITY_TOKEN_FRACTION,
use_xmoe=False,
) -> None:
# TODO: merge this to top2gate.py
#
super().__init__()
if not use_xmoe:
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
else:
self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
wg = torch.empty(num_experts, 16)
torch.nn.init.orthogonal_(wg, gain=0.32)
self.register_parameter("wg", torch.nn.Parameter(wg))
self.use_xmoe = use_xmoe
self.use_fp32 = use_fp32
self.input_noise_type = input_noise_type
self.capacity_factor = capacity_factor
self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
def forward(self, input, mask=None): # type: ignore
if self.use_xmoe:
input = self.wg_reduction(input)
with torch.no_grad():
wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
self.wg.mul_(1.5 / wg_norm)
logits = self._cosine(input, self.wg)
logits = self._make_finite(logits)
else:
logits = self.wg(input)
return top1gating(
logits,
mask,
use_fp32=self.use_fp32,
capacity_factor=self.capacity_factor,
eval_mode=not self.training,
moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
use_xmoe=self.use_xmoe,
gate_obj=self,
)
def _make_finite(self, scores):
ok = scores.isfinite()
if not ok.all():
# NaNs here can break the assignment algorithm
scores[~ok] = scores[ok].min()
return scores
def _get_gating_temperature(self, eps=1e-4):
if self.gating_t.data.item() < eps:
return eps
return self.gating_t
def _cosine(self, mat1, mat2, eps=1e-4):
assert mat1.dim() == 2
assert mat2.dim() == 2
# mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
gumbel_map: Dict[torch.device, Callable] = {}
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
gumbel = gumbel_map.get(device)
if gumbel is None:
one = torch.tensor(1.0, device=device)
zero = torch.tensor(0.0, device=device)
gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
gumbel_map[device] = gumbel
return gumbel(shape)
def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) -> Tensor:
if unsqueeze_indices:
indices = indices.unsqueeze(-1)
assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype)
output.scatter_(
len(output.shape) - 1, indices, 1
)
return output
def entropy(probs):
logits = torch.distributions.utils.probs_to_logits(probs)
p_log_p = probs * logits
return -p_log_p.sum(-1)
def top2gating(
logits: torch.Tensor,
input_mask: Optional[torch.Tensor] = None,
use_fp32=False,
second_expert_policy='sampling',
normalize_gate_prob_before_dropping=False,
eval_mode=False,
moe_eval_capacity_token_fraction=0.25,
batch_prioritized_routing=False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
metadata = {}
if use_fp32:
orig_dtype = logits.dtype
logits = logits.float()
gates = F.softmax(logits, dim=1)
metadata["entropy_gating"] = entropy(probs=gates).mean().detach()
# gates has shape of SE
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
if moe_eval_capacity_token_fraction > 0.0 and eval_mode:
capacity = math.ceil(moe_eval_capacity_token_fraction * num_tokens)
else:
# capacity = 2S/E
capacity = 2 * math.ceil(num_tokens / num_experts)
# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
mask1 = one_hot(indices1_s, num_experts)
if second_expert_policy == 'sampling':
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
else:
logits_w_noise = logits
# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
indices2_s = torch.argmax(logits_except1, dim=1, keepdim=True)
mask2 = one_hot(indices2_s, num_experts)
gates1_s = (gates * mask1).sum(dim=1)
gates2_s = (gates * mask2).sum(dim=1)
if normalize_gate_prob_before_dropping:
# Normalize gate probabilities
denom_s = gates1_s + gates2_s
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gates1_s = gates1_s / denom_s
gates2_s = gates2_s / denom_s
if second_expert_policy == 'random':
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
# Compute locations in capacity buffer
if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
if batch_prioritized_routing:
# if batch_prioritized_routing:
importance_scores = -1 * gates.max(dim=1)[0]
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
else:
locations1 = fused_cumsum_sub_one(mask1)
locations2 = fused_cumsum_sub_one(mask2)
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.to(gates.dtype), dim=0)
l_aux = torch.mean(me * ce)
l_aux = l_aux * num_experts * num_experts
# for logging purposes
metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
# Remove locations outside capacity from mask
mask1_, mask2_ = mask1, mask2
mask1 = mask1 * torch.lt(locations1, capacity)
mask2 = mask2 * torch.lt(locations2, capacity)
# for logging (percent of tokens routed to each expert)
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
metadata["expert1_balance_bottom"] = expert1_hist[-sample_count:].sum()
metadata["expert2_balance_top"] = expert2_hist[:sample_count].sum()
metadata["expert2_balance_bottom"] = expert2_hist[-sample_count:].sum()
if not normalize_gate_prob_before_dropping:
# Normalize gate probabilities
gates1_s = (gates * mask1).sum(dim=1)
gates2_s = (gates * mask2).sum(dim=1)
denom_s = gates1_s + gates2_s
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gates1_s /= denom_s
gates2_s /= denom_s
if has_tutel:
locations1_s = torch.sum(locations1 * mask1_, dim=1)
locations2_s = torch.sum(locations2 * mask2_, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
locations2_s = torch.sum(locations2 * mask2, dim=1)
# Calculate combine_weights and dispatch_mask
gates1 = gates1_s.unsqueeze(-1) * mask1.to(gates1_s.dtype) # einsum("s,se->se")
gates2 = gates2_s.unsqueeze(-1) * mask2.to(gates2_s.dtype) # einsum("s,se->se")
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
combine1_sec = torch.bmm(
# einsum("se,sc->sec")
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
)
combine2_sec = torch.bmm(
# einsum("se,sc->sec")
gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
)
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()
if use_fp32:
return l_aux, combine_weights.to(orig_dtype), dispatch_mask, metadata
else:
return l_aux, combine_weights, dispatch_mask, metadata
class Top2Gate(torch.nn.Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
gate = Top2Gate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
model_dim (int):
size of model embedding dimension
num_experts (ints):
number of experts in model
"""
wg: torch.nn.Linear
def __init__(
self,
model_dim: int,
num_experts: int,
use_fp32=False,
second_expert_policy='sampling',
normalize_gate_prob_before_dropping=False,
moe_eval_capacity_token_fraction=0.25,
batch_prioritized_routing=False,
use_xmoe=False,
) -> None:
super().__init__()
if not use_xmoe:
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
else:
self.wg_reduction = torch.nn.Linear(model_dim, 16, bias=False)
wg = torch.empty(num_experts, 16)
torch.nn.init.orthogonal_(wg, gain=0.32)
self.register_parameter("wg", torch.nn.Parameter(wg))
self.use_fp32 = use_fp32
self.second_expert_policy = second_expert_policy
self.normalize_gate_prob_before_dropping = normalize_gate_prob_before_dropping
self.moe_eval_capacity_token_fraction = moe_eval_capacity_token_fraction
self.batch_prioritized_routing = batch_prioritized_routing
self.use_xmoe = use_xmoe
def forward(self, input, mask=None): # type: ignore
if self.use_xmoe:
input = self.wg_reduction(input)
with torch.no_grad():
wg_norm = self.wg.norm(p=2.0, dim=1, keepdim=True)
self.wg.mul_(1.5 / wg_norm)
logits = self._cosine(input, self.wg)
logits = self._make_finite(logits)
else:
logits = self.wg(input)
return top2gating(
logits,
mask,
use_fp32=self.use_fp32,
second_expert_policy=self.second_expert_policy,
normalize_gate_prob_before_dropping=self.normalize_gate_prob_before_dropping,
eval_mode=not self.training,
moe_eval_capacity_token_fraction=self.moe_eval_capacity_token_fraction,
batch_prioritized_routing=self.batch_prioritized_routing,
)
def _cosine(self, mat1, mat2, eps=1e-4):
assert mat1.dim() == 2
assert mat2.dim() == 2
# mat1 = F.normalize(mat1, p=2.0, dim=1, eps=eps)
mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)
def _make_finite(self, scores):
ok = scores.isfinite()
if not ok.all():
# NaNs here can break the assignment algorithm
scores[~ok] = scores[ok].min()
return scores

@ -0,0 +1,85 @@
import torch
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
from torchscale.component.multiway_network import MultiwayWrapper
class BEiT3(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
assert args.multiway
assert args.vocab_size > 0
assert not args.share_encoder_input_output_embed
self.text_embed = TextEmbedding(
args.vocab_size,
args.encoder_embed_dim
)
self.vision_embed = VisionEmbedding(
args.img_size,
args.patch_size,
args.in_chans,
args.encoder_embed_dim,
contain_mask_token=True,
prepend_cls_token=True
)
embed_positions = MultiwayWrapper(
args,
PositionalEmbedding(
args.max_source_positions,
args.encoder_embed_dim
),
dim=1,
)
self.encoder = Encoder(
args,
embed_tokens=None,
embed_positions=embed_positions,
output_projection=None,
is_encoder_decoder=False,
)
def forward(
self,
textual_tokens=None,
visual_tokens=None,
text_padding_position=None,
vision_masked_position=None,
):
assert textual_tokens is not None or visual_tokens is not None
if textual_tokens is None:
x = self.vision_embed(visual_tokens, vision_masked_position)
encoder_padding_mask = None
multiway_split_position = -1
elif visual_tokens is None:
x = self.text_embed(textual_tokens)
encoder_padding_mask = text_padding_position
multiway_split_position = 0
else:
x1 = self.vision_embed(visual_tokens, vision_masked_position)
multiway_split_position = x1.size(1)
x2 = self.text_embed(textual_tokens)
x = torch.cat([x1, x2], dim=1)
if text_padding_position is not None:
encoder_padding_mask = torch.cat(
[
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
text_padding_position
],
dim=1,
)
else:
encoder_padding_mask = None
encoder_out = self.encoder(
src_tokens=None,
encoder_padding_mask=encoder_padding_mask,
token_embeddings=x,
multiway_split_position=multiway_split_position,
)
return encoder_out