Compare commits


39 Commits
0.2.0 ... main

Author SHA1 Message Date
mrq
008f1b6d18 added compat flags because I guess the maintainer assumed no one was actually using the retnet and thinks they can change things willy nilly 2023-10-05 16:38:57 -05:00
mrq
ce77afe916 added arg to change RelPos's base 2023-10-05 16:02:08 -05:00
Shuming Ma
881d03079d
Merge pull request #70 from sunyt32/retnet-official
fix fairseq example
2023-09-29 12:59:56 +08:00
sunyt32
50174a3078 fix fairseq example 2023-09-29 03:50:24 +00:00
Shuming Ma
ab1d9d677a
Merge pull request #69 from sunyt32/retnet-official
Update new RetNet settings
2023-09-29 10:07:48 +08:00
sunyt32
05a9628309 fix bug 2023-09-28 17:39:26 +00:00
sunyt32
59fc5f7d3d fix bug 2023-09-28 17:05:53 +00:00
sunyt32
fd8234c2ac rollback variant name 2023-09-28 16:44:51 +00:00
sunyt32
7f07609361 update README.md 2023-09-28 14:26:56 +00:00
sunyt32
5c89ffbeea modify rms norm and value dim in retention 2023-09-28 14:24:37 +00:00
Li Dong
d1fefe9c22
rollback LN epsilon in retention
rollback 2c29de0fb3
2023-09-27 20:40:36 +08:00
Shuming Ma
258eda3308
Update vocab links 2023-08-11 16:46:37 +08:00
Shuming Ma
70e047a53b
Update README.md 2023-08-11 11:26:19 +08:00
Li Dong
8b07f19ba0
Update README.md 2023-08-10 13:15:42 +08:00
Shuming Ma
e2db7ae123
Merge pull request #51 from sunyt32/retnet-official
fix chunkwise inconsistency bug
2023-08-04 13:51:53 +08:00
sunyt32
0b1f113985 fix chunkwise inconsistency bug 2023-08-04 05:48:58 +00:00
Li Dong
0faee72d6f
Merge pull request #50 from wangmengzhi/main-2
Adding sqrt in the recurrent_forward of retnet to make it consistent with parallel_forward
2023-08-04 09:02:38 +08:00
wangmengzhi
7f0bf80a7e
Adding sqrt in the recurrent_forward of retnet to make it consistent with parallel_forward
Adding sqrt in the recurrent_forward of retnet can avoid numerical underflow thus improving consistency and performance. https://github.com/microsoft/torchscale/issues/47
2023-08-04 08:18:10 +08:00
Shuming Ma
7d231743f4
Merge pull request #46 from sunyt32/retnet-official
Update epsilon in retention
2023-08-02 14:07:44 +08:00
sunyt32
2c29de0fb3 Update epsilon in retention 2023-08-02 05:38:30 +00:00
Shuming Ma
5356b252c4 Update MT config 2023-07-31 09:17:03 -07:00
gitnlp
ea07735c7b
Update README.md 2023-07-27 12:33:20 +08:00
gitnlp
63a4f2df2e
Update README.md 2023-07-27 12:32:37 +08:00
gitnlp
3573227c88
Update README.md 2023-07-26 18:40:30 +08:00
gitnlp
f58c8247be
Update README.md 2023-07-26 18:39:23 +08:00
gitnlp
774003903e
Update README.md 2023-07-26 18:38:49 +08:00
Li Dong
bf65397b26 RetNet 2023-07-24 14:30:13 +08:00
Shuming Ma
89d8c6de9e Fix encdec config issue 2023-07-23 23:07:21 -07:00
Shuming Ma
2b101355d7
Merge pull request #33 from XingxingZhang/lm_sampling
support lm prefix computation in one go
2023-06-04 02:01:54 +08:00
Shuming Ma
859c4c6bcc
Merge pull request #31 from klae01/feature/pytest
add basic test
2023-06-04 02:01:29 +08:00
Shuming Ma
0cef83d675
Merge pull request #27 from JonathanRayner/patch-1
Bump timm version to latest
2023-06-04 02:01:04 +08:00
Xingxing Zhang
c766630327 support lm prefix computation in one go 2023-06-03 15:37:47 +00:00
klae01
d675712846 add basic test 2023-05-26 13:30:00 +09:00
Shuming Ma
b59b6f87b9
Merge pull request #30 from njb-ms/main
make pgs global
2023-04-26 12:12:55 +08:00
johan bjorck
a4d830b87d make num experts optional arg 2023-04-24 17:29:39 +00:00
johan bjorck
886c8ab408 make pgs global 2023-04-22 00:32:05 +00:00
Jonathan Rayner
691e843ed5
Bump timm version to latest 2023-04-11 11:02:21 +01:00
Shuming Ma
4ae3b248ee
Merge pull request #21 from microsoft/0.2.0
v0.2.0
2023-03-15 15:21:52 +08:00
Li Dong
36c0e55004
Update README.md 2023-03-13 21:40:20 +08:00
16 changed files with 1310 additions and 81 deletions

.github/workflows/test.yml
View File

@ -0,0 +1,26 @@
name: Python package
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f setup.py ]; then pip install .; fi
- name: Install pytest
run: |
pip install pytest
- name: Run tests
run: |
pytest tests/

View File

@ -1,4 +1,4 @@
# TorchScale - A Library for Transformers at (Any) Scale
# TorchScale - A Library of Foundation Architectures
<p>
<a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
@ -6,15 +6,20 @@
</p>
TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
It has the implementation of fundamental research to improve modeling generality and capability as well as training stability and efficiency of scaling Transformers.
Fundamental research to develop new architectures for foundation models and A(G)I, focusing on modeling generality and capability, as well as training stability and efficiency.
- Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
- Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal)
- Capability - A [**Length-Extrapolatable**](https://arxiv.org/abs/2212.10554) Transformer
- Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
### Revolutionizing Transformers for (M)LLMs and AI
- [**RetNet**](https://arxiv.org/abs/2307.08621): Retentive Network: A Successor to Transformer for Large Language Models
- [**LongNet**](https://arxiv.org/abs/2307.02486): Scaling Transformers to 1,000,000,000 Tokens
## News
- October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
- November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
## Installation
@ -65,6 +70,20 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
>>> print(encdec)
```
It takes only several lines of code to create a RetNet model:
```python
# Creating a RetNet model
>>> import torch
>>> from torchscale.architecture.config import RetNetConfig
>>> from torchscale.architecture.retnet import RetNetDecoder
>>> config = RetNetConfig(vocab_size=64000)
>>> retnet = RetNetDecoder(config)
>>> print(retnet)
```
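A minimal forward-pass sketch (an illustrative aside, not taken from the README itself: it assumes the `RetNetDecoder` interface and the `retnet_base`-sized dimensions that appear later in this diff, and supplies the embedding table explicitly so the call runs end to end):
```python
import torch
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

# Dimensions follow the retnet_base preset defined further down in this diff.
config = RetNetConfig(vocab_size=64000, decoder_embed_dim=512,
                      decoder_value_embed_dim=864, decoder_ffn_embed_dim=864,
                      decoder_retention_heads=2, decoder_layers=6)
embed = torch.nn.Embedding(64000, config.decoder_embed_dim)
retnet = RetNetDecoder(config, embed_tokens=embed)

tokens = torch.randint(0, 64000, (1, 16))   # (batch, seq_len), dummy token ids
logits, extra = retnet(tokens)              # parallel (training) form
print(logits.shape)                         # expected: torch.Size([1, 16, 64000])
```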
## Key Features
- [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
@ -93,6 +112,9 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
- [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184)
* we provide a [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repo.
- [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621)
* created by `config = RetNetConfig(vocab_size=64000)` and `retnet = RetNetDecoder(config)`.
Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
```python
@ -107,7 +129,7 @@ Most of the features above can be used by simply passing the corresponding param
## Examples
We have the examples of how to use TorchScale in the following scenarios/tasks:
We have examples of how to use TorchScale in the following scenarios/tasks:
- Language
@ -125,7 +147,7 @@ We have the examples of how to use TorchScale in the following scenarios/tasks:
- Multimodal
* [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
* [Multiway Transformers/BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3)
We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
@ -134,7 +156,7 @@ We plan to provide more examples regarding different tasks (e.g. vision pretrain
### Stability Evaluation
<p align="center">
<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png" width="800"/>
<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
</p>
The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
@ -142,7 +164,7 @@ The training curve is smooth by using TorchScale, while the baseline Transformer
### Scaling-up Experiments
<p align="center">
<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png" width="800"/>
<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
</p>
TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain.
@ -195,6 +217,16 @@ If you find this repository useful, please consider citing our work:
}
```
```
@article{retnet,
author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
title = {Retentive Network: A Successor to {Transformer} for Large Language Models},
journal = {ArXiv},
volume = {abs/2307.08621},
year = {2023}
}
```
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
@ -206,13 +238,11 @@ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply fo
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
## Trademarks
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.
Any use of third-party trademarks or logos is subject to those third-party's policies.

View File

@ -65,7 +65,7 @@ Also, the JSON file should be in the format like this:
]
```
You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentencepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentencepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
```
spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
```

View File

@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
from torch import Tensor
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
from torchscale.architecture.encoder import Encoder
from .language_modeling import LMDecoder as MTDecoder
@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
@classmethod
def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
config = EncoderConfig()
config = EncoderDecoderConfig()
config.override(args)
return MTEncoder(
@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
def build_decoder(
cls, args, embed_tokens, embed_positions, output_projection, dictionary
):
config = DecoderConfig()
config = EncoderDecoderConfig()
config.override(args)
return MTDecoder(

View File

@ -0,0 +1,387 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: str = field(
default="swish", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
relu_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_value_embed_dim: int = field(
default=864, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
decoder_input_dim: int = field(
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=864, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_retention_heads: int = field(
default=2, metadata={"help": "num decoder retention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply norm before each decoder block"}
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add norm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
},
)
moe_freq: int = field(
default=0,
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
},
)
moe_normalize_expert_grad: Optional[str] = field(
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
)
chunkwise_recurrent: Optional[bool] = field(
default=False,
)
recurrent_chunk_size: Optional[int] = field(
default=512,
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("common.tpu")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
fp16: bool = II("common.fp16")
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
ddp_backend: str = II("distributed_training.ddp_backend")
world_size: int = II("distributed_training.distributed_world_size")
distributed_rank: int = II("distributed_training.distributed_rank")
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = RetNetConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
output_projection,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return Embedding(len(dictionary), embed_dim, dictionary.pad())
class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
return super().forward(src_tokens, **kwargs)
def max_positions(self):
return self.args.max_target_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.dropout = getattr(args, "dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "swish")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.base_layers = getattr(args, "base_layers", 0)
args.base_sublayers = getattr(args, "base_sublayers", 1)
args.base_shuffle = getattr(args, "base_shuffle", False)
args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
args.decoder_layers = getattr(args, "decoder_layers", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)

View File

@ -17,7 +17,7 @@ setup(
license="MIT",
url="https://github.com/microsoft/torchscale",
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13"],
python_requires=">=3.8.0",
classifiers=[
"Programming Language :: Python :: 3",

View File

@ -142,6 +142,7 @@ class EncoderDecoderConfig(object):
self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
self.encoder_layers = kwargs.pop("encoder_layers", 12)
self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
self.normalize_output = kwargs.pop("normalize_output", True)
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
@ -206,3 +207,77 @@ class EncoderDecoderConfig(object):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
class RetNetConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
# Blockwise
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
# RetNet's RelPos base
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
# Backwards compatibility flags
self.use_layernorm = kwargs.pop("use_layernorm", False)
self.use_biases = kwargs.pop("use_biases", False)
self.use_glu = kwargs.pop("use_glu", True)
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
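The backwards-compatibility flags at the end of `RetNetConfig` are plain keyword arguments; a small hedged sketch of how they might be flipped for a hypothetical checkpoint trained before the RMSNorm/SwiGLU switch mentioned in the README news above:
```python
from torchscale.architecture.config import RetNetConfig

# Current defaults: RMSNorm, bias-free projections, GLU (SwiGLU) feed-forward.
new_style = RetNetConfig(vocab_size=64000)

# Hypothetical config for weights trained before the switch: the compat flags
# select LayerNorm, biased q/k/v/g/out projections and a plain FeedForwardNetwork.
old_style = RetNetConfig(
    vocab_size=64000,
    use_layernorm=True,
    use_biases=True,
    use_glu=False,
)
```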

View File

@ -140,6 +140,7 @@ class DecoderLayer(nn.Module):
self_attn_padding_mask=None,
self_attn_rel_pos=None,
cross_attn_rel_pos=None,
is_first_step=False,
):
residual = x
if self.normalize_before:
@ -153,6 +154,7 @@ class DecoderLayer(nn.Module):
incremental_state=incremental_state,
attn_mask=self_attn_mask,
rel_pos=self_attn_rel_pos,
is_first_step=is_first_step,
)
x = self.dropout_module(x)
@ -357,7 +359,7 @@ class Decoder(nn.Module):
tokens, incremental_state=incremental_state
)
if incremental_state is not None:
if incremental_state is not None and not self.is_first_step(incremental_state):
tokens = tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
@ -377,6 +379,11 @@ class Decoder(nn.Module):
return x, embed
def is_first_step(self, incremental_state):
if incremental_state is None:
return False
return incremental_state.get("is_first_step", False)
def forward(
self,
prev_output_tokens,
@ -392,6 +399,7 @@ class Decoder(nn.Module):
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
is_first_step = self.is_first_step(incremental_state)
# relative position
self_attn_rel_pos_bias = None
@ -400,7 +408,7 @@ class Decoder(nn.Module):
self_attn_rel_pos_bias = self.self_attn_relative_position(
batch_size=x.size(0), qlen=slen, klen=slen
)
if incremental_state is not None:
if incremental_state is not None and not is_first_step:
self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
cross_attn_rel_pos_bias = None
if self.cross_attn_relative_position is not None:
@ -409,7 +417,7 @@ class Decoder(nn.Module):
qlen=slen,
klen=encoder_out["encoder_out"].size(1),
)
if incremental_state is not None:
if incremental_state is not None and not is_first_step:
cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
# decoder layers
@ -421,7 +429,7 @@ class Decoder(nn.Module):
l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
for idx, layer in enumerate(self.layers):
if incremental_state is None:
if incremental_state is None or is_first_step:
self_attn_mask = torch.triu(
torch.zeros([x.size(1), x.size(1)])
.float()
@ -429,6 +437,9 @@ class Decoder(nn.Module):
.type_as(x),
1,
)
if is_first_step and incremental_state is not None:
if idx not in incremental_state:
incremental_state[idx] = {}
else:
self_attn_mask = None
if idx not in incremental_state:
@ -445,6 +456,7 @@ class Decoder(nn.Module):
self_attn_padding_mask=self_attn_padding_mask,
self_attn_rel_pos=self_attn_rel_pos_bias,
cross_attn_rel_pos=cross_attn_rel_pos_bias,
is_first_step=is_first_step,
)
l_aux.append(l_aux_i)
inner_states.append(x)
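One way to read the `is_first_step` changes above: the first call can push the whole prompt through the decoder with a full causal mask while the per-layer caches are initialised, and later calls only process the newest token. A hedged driver-loop sketch (the `Decoder` construction mirrors the library's decoder-only quick start; the import path, the loop, and the caller-managed flag are assumptions, not part of the diff):
```python
import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder   # assumed import path

config = DecoderConfig(vocab_size=64000)
embed = torch.nn.Embedding(64000, config.decoder_embed_dim)
decoder = Decoder(config, embed_tokens=embed)

prompt = torch.randint(0, 64000, (1, 8))
cache = {"is_first_step": True}                 # first call: whole prefix in one go
logits, _ = decoder(prompt, incremental_state=cache)
cache["is_first_step"] = False                  # later calls: only the newest token

for _ in range(4):                              # greedy continuation, purely illustrative
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    prompt = torch.cat([prompt, next_token], dim=-1)
    logits, _ = decoder(prompt, incremental_state=cache)
```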

View File

@ -0,0 +1,403 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
from torchscale.component.gate_linear_unit import GLU
from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm
class RetNetRelPos(nn.Module):
def __init__(self, args):
super().__init__()
angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
self.register_buffer("angle", angle)
self.register_buffer("decay", decay)
self.recurrent_chunk_size = args.recurrent_chunk_size
def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
if activate_recurrent:
sin = torch.sin(self.angle * (slen - 1))
cos = torch.cos(self.angle * (slen - 1))
retention_rel_pos = ((sin, cos), self.decay.exp())
elif chunkwise_recurrent:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
mask = torch.tril(torch.ones(self.recurrent_chunk_size, self.recurrent_chunk_size).to(self.decay))
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
value_inner_decay = value_inner_decay.unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale
cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
cross_decay = cross_decay[:, None, None]
retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
else:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
retention_rel_pos = ((sin, cos), mask)
return retention_rel_pos
class DecoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
):
super().__init__()
self.args = args
self.embed_dim = args.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(args.dropout)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = self.build_retention(self.embed_dim, args)
self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
if not self.is_moe_layer:
self.ffn = self.build_ffn(
self.embed_dim,
self.args,
)
else:
if args.moe_top1_expert:
gate = Top1Gate(
self.embed_dim,
args.moe_expert_count,
use_fp32=args.moe_gating_use_fp32,
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
else:
gate = Top2Gate(
self.embed_dim,
args.moe_expert_count,
args.moe_gating_use_fp32,
args.moe_second_expert_policy,
args.moe_normalize_gate_prob_before_dropping,
args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
else:
self.alpha = 1.0
def build_ffn(self, embed_dim, args):
return GLU(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
) if args.use_glu else FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
def build_retention(self, embed_dim, args):
return MultiScaleRetention(
args,
embed_dim,
args.decoder_value_embed_dim,
args.decoder_retention_heads,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
if not self.is_moe_layer:
x = self.ffn(x)
l_aux = None
else:
x, l_aux = self.moe_layer(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, l_aux
class RetNetDecoder(nn.Module):
def __init__(
self,
args,
embed_tokens=None,
output_projection=None,
**kwargs
):
super().__init__(**kwargs)
self.args = args
self.dropout_module = torch.nn.Dropout(args.dropout)
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
if (
output_projection is None
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
moe_freq = args.moe_freq
for i in range(args.decoder_layers):
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
self.layers.append(
self.build_decoder_layer(
args,
depth=i,
is_moe_layer=is_moe_layer,
)
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(args)
self.chunkwise_recurrent = args.chunkwise_recurrent
self.recurrent_chunk_size = args.recurrent_chunk_size
if args.deepnorm:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
def build_output_projection(
self,
args,
):
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
return output_projection
def build_decoder_layer(
self, args, depth, is_moe_layer=False
):
layer = DecoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
if args.fsdp:
layer = wrap(layer)
return layer
def forward_embedding(
self,
tokens,
token_embedding=None,
incremental_state=None,
):
if incremental_state is not None and not self.is_first_step(incremental_state):
tokens = tokens[:, -1:]
if token_embedding is None:
token_embedding = self.embed_tokens(tokens)
x = embed = self.embed_scale * token_embedding
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def is_first_step(self, incremental_state):
if incremental_state is None:
return False
return incremental_state.get("is_first_step", False)
def forward(
self,
prev_output_tokens,
incremental_state=None,
features_only=False,
return_all_hiddens=False,
token_embeddings=None,
**kwargs
):
# embed tokens
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
is_first_step = self.is_first_step(incremental_state)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
slen = prev_output_tokens.size(1) + padding_len
x = F.pad(x, (0, 0, 0, padding_len))
else:
slen = prev_output_tokens.size(1)
# relative position
retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
# decoder layers
inner_states = [x]
l_aux = []
for idx, layer in enumerate(self.layers):
if incremental_state is None or is_first_step:
if is_first_step and incremental_state is not None:
if idx not in incremental_state:
incremental_state[idx] = {}
else:
if idx not in incremental_state:
incremental_state[idx] = {}
x, l_aux_i = layer(
x,
incremental_state[idx] if incremental_state is not None else None,
retention_rel_pos=retention_rel_pos,
chunkwise_recurrent=self.chunkwise_recurrent,
)
l_aux.append(l_aux_i)
inner_states.append(x)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
x = x[:, :prev_output_tokens.size(1), :]
if self.layer_norm is not None:
x = self.layer_norm(x)
if not features_only:
x = self.output_layer(x)
return x, {
"inner_states": inner_states,
"l_aux": l_aux,
"attn": None,
}
def output_layer(self, features):
return self.output_projection(features)
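The decoder above supports three execution paths: parallel over the full sequence, chunkwise recurrent, and token-by-token recurrent via `incremental_state`. A hedged sketch comparing the parallel and recurrent forms (same assumptions as the quick-start sketch earlier in this diff; the two outputs are only expected to agree up to numerical tolerance):
```python
import torch
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

config = RetNetConfig(vocab_size=64000, decoder_embed_dim=512,
                      decoder_value_embed_dim=864, decoder_ffn_embed_dim=864,
                      decoder_retention_heads=2, decoder_layers=6)
embed = torch.nn.Embedding(64000, config.decoder_embed_dim)
retnet = RetNetDecoder(config, embed_tokens=embed).eval()

tokens = torch.randint(0, 64000, (1, 12))
with torch.no_grad():
    parallel_logits, _ = retnet(tokens)          # all positions at once

    cache, steps = {}, []
    for t in range(1, tokens.size(1) + 1):       # recurrent: one new token per call
        out, _ = retnet(tokens[:, :t], incremental_state=cache)
        steps.append(out[:, -1])
    recurrent_logits = torch.stack(steps, dim=1)

print((parallel_logits - recurrent_logits).abs().max())   # expected to be small
```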

View File

@ -10,6 +10,9 @@ except ModuleNotFoundError:
from torch.nn import LayerNorm
from .xmoe.global_groups import get_moe_group
class set_torch_seed(object):
def __init__(self, seed):
assert isinstance(seed, int)
@ -70,7 +73,9 @@ def make_experts(args, embed_dim, expert_ffn_dim):
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
moe_idx, _ = get_moe_group(args.moe_expert_count)
with set_torch_seed(start_seed + moe_idx):
expert_list.append(
FeedForwardNetwork(
embed_dim,
@ -91,6 +96,8 @@ def get_activation_fn(activation):
return F.relu
elif activation == "gelu":
return F.gelu
elif activation == "swish":
return F.silu
else:
raise NotImplementedError

View File

@ -0,0 +1,44 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
import torch.nn.functional as F
from .feedforward_network import get_activation_fn
class GLU(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
self.gate.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
g = self.gate(x)
x = self.fc1(x)
x = self.activation_fn(x.float()).type_as(x) * g
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x
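For reference, the module above is the SwiGLU-style feed-forward that the recent commits make the default in RetNet; a minimal usage sketch, assuming the `torchscale.component.gate_linear_unit` import path used by `retnet.py` earlier in this diff:
```python
import torch
from torchscale.component.gate_linear_unit import GLU

glu = GLU(embed_dim=8, ffn_dim=16, activation_fn="swish",
          dropout=0.0, activation_dropout=0.0)

x = torch.randn(2, 4, 8)
out = glu(x)            # computes fc2(swish(fc1(x)) * gate(x)), bias-free
print(out.shape)        # torch.Size([2, 4, 8]) -- same shape as the input
```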

View File

@ -71,6 +71,7 @@ class MultiheadAttention(nn.Module):
key_padding_mask=None,
attn_mask=None,
rel_pos=None,
is_first_step=False,
):
bsz, tgt_len, embed_dim = query.size()
src_len = tgt_len
@ -112,7 +113,7 @@ class MultiheadAttention(nn.Module):
src_len = k.size(1)
if self.xpos is not None:
if incremental_state is not None:
if incremental_state is not None and not is_first_step:
offset = src_len - 1
else:
offset = 0

View File

@ -0,0 +1,210 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn.functional as F
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def theta_shift(x, sin, cos):
return (x * cos) + (rotate_every_two(x) * sin)
def get_activation_fn(activation):
if activation == "swish":
return F.silu
elif activation == "gelu":
return F.gelu
else:
raise NotImplementedError
class MultiScaleRetention(nn.Module):
def __init__(
self,
args,
embed_dim,
value_dim,
num_heads,
gate_fn="swish",
):
super().__init__()
self.args = args
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight)
if hasattr(self.out_proj, "bias"):
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()
vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
qk_mat = qk_mat * mask
# invariant after normalization
qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
output = torch.matmul(qk_mat, vr)
output = output.transpose(1, 2)
return output
def recurrent_forward(
self,
qr, kr, v,
decay,
incremental_state
):
bsz = v.size(0)
v = v.view(bsz, self.num_heads, self.head_dim, 1)
kv = kr * v
if "prev_key_value" in incremental_state:
prev_kv = incremental_state["prev_key_value"]
prev_scale = incremental_state["scale"]
scale = prev_scale * decay + 1
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
# kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
else:
scale = torch.ones_like(decay)
incremental_state["prev_key_value"] = kv
incremental_state["scale"] = scale
output = torch.sum(qr * kv, dim=3)
return output
def chunk_recurrent_forward(
self,
qr, kr, v,
inner_mask
):
mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
bsz, tgt_len, embed_dim = v.size()
chunk_len = mask.size(1)
num_chunks = tgt_len // chunk_len
assert tgt_len % chunk_len == 0
qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
kr_t = kr.transpose(-1, -2)
qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
qk_mat = qk_mat * mask
inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
qk_mat = qk_mat / inner_scale
inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
# reduce kv in one chunk
kv = kr_t @ (v * value_inner_decay)
kv_recurrent = []
cross_scale = []
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
# accumulate kv by loop
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
kv_recurrent = torch.stack(kv_recurrent, dim=1)
cross_scale = torch.stack(cross_scale, dim=1)
all_scale = torch.maximum(inner_scale, cross_scale)
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale
cross_output = (qr * query_inner_decay) @ kv_recurrent
output = inner_output / align_inner_scale + cross_output / align_cross_scale
# output = inner_output / cross_scale + cross_output / inner_scale
output = output.transpose(2, 3)
return output
def forward(
self,
x,
rel_pos,
chunkwise_recurrent=False,
incremental_state=None
):
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
g = self.g_proj(x)
k *= self.scaling
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output

View File

@ -0,0 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
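RMSNorm rescales by the root mean square of the features without mean-centring, which is why it can also be used without an affine weight as the per-head group norm in `MultiScaleRetention`. A small numeric sketch, assuming the `torchscale.component.rms_norm` import path used elsewhere in this diff:
```python
import torch
from torchscale.component.rms_norm import RMSNorm

x = torch.randn(2, 4, 8)
norm = RMSNorm(8)        # weight initialised to ones

# The same thing by hand: scale by 1/sqrt(mean(x**2) + eps), no mean subtraction.
manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), manual))   # True
```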

View File

@ -0,0 +1,65 @@
import torch.distributed as dist
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count=None):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_expert_count = moe_expert_count
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return my_group_idx, get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]

View File

@ -18,6 +18,8 @@ import torch.distributed as dist
from torch import Tensor
from torch.nn import Module, ModuleList
from .global_groups import get_all2all_group, get_moe_group
try:
from fairseq.modules.moe import MOELayer
@ -61,64 +63,6 @@ class _AllToAll(torch.autograd.Function):
return (None, _AllToAll.apply(ctx.group, *grad_output))
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
class MOELayer(Base):
@ -149,7 +93,7 @@ class MOELayer(Base):
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.expert_group = get_moe_group(args.moe_expert_count)
_, self.expert_group = get_moe_group(args.moe_expert_count)
self.all2all_group = get_all2all_group(args.moe_expert_count)
self.world_size = dist.get_world_size(group=self.expert_group)
self.all2all_size = dist.get_world_size(group=self.all2all_group)