Compare commits


No commits in common. "main" and "0.2.0" have entirely different histories.
main ... 0.2.0

16 changed files with 81 additions and 1310 deletions

View File

@@ -1,26 +0,0 @@
name: Python package
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f setup.py ]; then pip install .; fi
- name: Install pytest
run: |
pip install pytest
- name: Run tests
run: |
pytest tests/

View File

@@ -1,4 +1,4 @@
-# TorchScale - A Library of Foundation Architectures
+# TorchScale - A Library for Transformers at (Any) Scale
 <p>
 <a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
@@ -6,20 +6,15 @@
 </p>
 TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
-Fundamental research to develop new architectures for foundation models and A(G)I, focusing on modeling generality and capability, as well as training stability and efficiency.
+It has the implementation of fundamental research to improve modeling generality and capability as well as training stability and efficiency of scaling Transformers.
 - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
 - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal)
 - Capability - A [**Length-Extrapolatable**](https://arxiv.org/abs/2212.10554) Transformer
 - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
-### Revolutionizing Transformers for (M)LLMs and AI
-- [**RetNet**](https://arxiv.org/abs/2307.08621): Retentive Network: A Successor to Transformer for Large Language Models
-- [**LongNet**](https://arxiv.org/abs/2307.02486): Scaling Transformers to 1,000,000,000 Tokens
 ## News
-- October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
 - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
 ## Installation
@@ -70,20 +65,6 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
 >>> print(encdec)
 ```
-It takes only several lines of code to create a RetNet model:
-```python
-# Creating a RetNet model
->>> import torch
->>> from torchscale.architecture.config import RetNetConfig
->>> from torchscale.architecture.retnet import RetNetDecoder
->>> config = RetNetConfig(vocab_size=64000)
->>> retnet = RetNetDecoder(config)
->>> print(retnet)
-```
 ## Key Features
 - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
@@ -112,9 +93,6 @@ It takes only several lines of code to create a RetNet model:
 - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184)
 * we provide a [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repo.
-- [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621)
-* created by `config = RetNetConfig(vocab_size=64000)` and `retnet = RetNetDecoder(config)`.
 Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
 ```python
@@ -129,7 +107,7 @@ Most of the features above can be used by simply passing the corresponding param
 ## Examples
-We have examples of how to use TorchScale in the following scenarios/tasks:
+We have the examples of how to use TorchScale in the following scenarios/tasks:
 - Language
@@ -147,7 +125,7 @@ We have examples of how to use TorchScale in the following scenarios/tasks:
 - Multimodal
-* [Multiway Transformers/BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3)
+* [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
 We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
@@ -156,7 +134,7 @@ We plan to provide more examples regarding different tasks (e.g. vision pretrain
 ### Stability Evaluation
 <p align="center">
-<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
+<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png" width="800"/>
 </p>
 The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
@@ -164,7 +142,7 @@ The training curve is smooth by using TorchScale, while the baseline Transformer
 ### Scaling-up Experiments
 <p align="center">
-<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
+<img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png" width="800"/>
 </p>
 TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain.
@@ -217,16 +195,6 @@ If you find this repository useful, please consider citing our work:
 }
 ```
-```
-@article{retnet,
-author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
-title = {Retentive Network: A Successor to {Transformer} for Large Language Models},
-journal = {ArXiv},
-volume = {abs/2307.08621},
-year = {2023}
-}
-```
 ## Contributing
 This project welcomes contributions and suggestions. Most contributions require you to agree to a
@@ -238,11 +206,13 @@ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply fo
 provided by the bot. You will only need to do this once across all repos using our CLA.
 This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
-For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
 contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
 ## Trademarks
-This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
+trademarks or logos is subject to and must follow
+[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
 Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
-Any use of third-party trademarks or logos is subject to those third-party's policies.
+Any use of third-party trademarks or logos are subject to those third-party's policies.
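Note: the README hunk above removes the RetNet quickstart but keeps the config-driven quickstart for the standard architectures. For context, a minimal sketch of that surviving pattern; `EncoderConfig`, `DecoderConfig`, and `torchscale.architecture.encoder` are imported elsewhere in this diff, while the `torchscale.architecture.decoder` import path is an assumption.

```python
# Sketch of the quickstart kept on both branches; the removed RetNet snippet
# above follows the same "config -> model" pattern.
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.decoder import Decoder   # module path assumed
from torchscale.architecture.encoder import Encoder

encoder = Encoder(EncoderConfig(vocab_size=64000))    # BERT/ViT-style encoder
decoder = Decoder(DecoderConfig(vocab_size=64000))    # GPT-style decoder
print(encoder)
print(decoder)
```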

View File

@@ -65,7 +65,7 @@ Also, the JSON file should be in the format like this:
 ]
 ```
-You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
+You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
 ```
 spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
 ```
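The shell pipeline in the hunk above can also be done from Python. A rough sketch (an assumption, not part of this diff), using the `sentencepiece` package and mirroring `spm_export_vocab | sed | tail -n +4`: the exported vocab is "piece<TAB>score", and the first three special pieces are skipped.

```python
# Rough Python equivalent of the spm_export_vocab pipeline above: write
# "piece score" lines, skipping the first three special pieces (<unk>, <s>, </s>)
# just like `tail -n +4`.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="sentencepiece.bpe.model")
with open("dict.txt", "w", encoding="utf-8") as f:
    for idx in range(3, sp.get_piece_size()):
        f.write(f"{sp.id_to_piece(idx)} {sp.get_score(idx)}\n")
```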

View File

@@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
 from fairseq.modules import PositionalEmbedding
 from torch import Tensor
-from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
 from torchscale.architecture.encoder import Encoder
 from .language_modeling import LMDecoder as MTDecoder
@@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
 @classmethod
 def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
-config = EncoderDecoderConfig()
+config = EncoderConfig()
 config.override(args)
 return MTEncoder(
@@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
 def build_decoder(
 cls, args, embed_tokens, embed_positions, output_projection, dictionary
 ):
-config = EncoderDecoderConfig()
+config = DecoderConfig()
 config.override(args)
 return MTDecoder(

View File

@@ -1,387 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: str = field(
default="swish", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
relu_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_value_embed_dim: int = field(
default=864, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
decoder_input_dim: int = field(
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=864, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_retention_heads: int = field(
default=2, metadata={"help": "num decoder retention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply norm before each decoder block"}
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add norm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
},
)
moe_freq: int = field(
default=0,
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
},
)
moe_normalize_expert_grad: Optional[str] = field(
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
)
chunkwise_recurrent: Optional[bool] = field(
default=False,
)
recurrent_chunk_size: Optional[int] = field(
default=512,
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("common.tpu")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
fp16: bool = II("common.fp16")
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
ddp_backend: str = II("distributed_training.ddp_backend")
world_size: int = II("distributed_training.distributed_world_size")
distributed_rank: int = II("distributed_training.distributed_rank")
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = RetNetConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
output_projection,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return Embedding(len(dictionary), embed_dim, dictionary.pad())
class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
return super().forward(src_tokens, **kwargs)
def max_positions(self):
return self.args.max_target_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.dropout = getattr(args, "dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "swish")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.base_layers = getattr(args, "base_layers", 0)
args.base_sublayers = getattr(args, "base_sublayers", 1)
args.base_shuffle = getattr(args, "base_shuffle", False)
args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
args.decoder_layers = getattr(args, "decoder_layers", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)

View File

@@ -17,7 +17,7 @@ setup(
 license="MIT",
 url="https://github.com/microsoft/torchscale",
 packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13"],
+install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
 python_requires=">=3.8.0",
 classifiers=[
 "Programming Language :: Python :: 3",

View File

@@ -142,7 +142,6 @@ class EncoderDecoderConfig(object):
 self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
 self.encoder_layers = kwargs.pop("encoder_layers", 12)
 self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
-self.normalize_output = kwargs.pop("normalize_output", True)
 self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
 self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
 self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
@@ -207,77 +206,3 @@ class EncoderDecoderConfig(object):
 for hp in self.__dict__.keys():
 if getattr(args, hp, None) is not None:
 self.__dict__[hp] = getattr(args, hp, None)
class RetNetConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
# Blockwise
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
# RetNet's RelPos base
self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
# Backwards compatibility flags
self.use_layernorm = kwargs.pop("use_layernorm", False)
self.use_biases = kwargs.pop("use_biases", False)
self.use_glu = kwargs.pop("use_glu", True)
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
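The `override` method shown above (on both `EncoderDecoderConfig` and the removed `RetNetConfig`) is how the fairseq wrappers in this diff push command-line `args` into a config. A small sketch of that behaviour, assuming the main-branch `RetNetConfig` defined above:

```python
# Sketch of how override() is used elsewhere in this diff: any attribute of
# `args` whose name matches a config field and is not None replaces the default.
from argparse import Namespace
from torchscale.architecture.config import RetNetConfig  # main-branch class shown above

config = RetNetConfig(vocab_size=64000, deepnorm=True)
print(config.decoder_normalize_before)   # False: deepnorm turns off pre-norm and subln
config.override(Namespace(decoder_layers=24, dropout=0.1))
print(config.decoder_layers, config.dropout)   # 24 0.1
```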

View File

@@ -140,7 +140,6 @@ class DecoderLayer(nn.Module):
 self_attn_padding_mask=None,
 self_attn_rel_pos=None,
 cross_attn_rel_pos=None,
-is_first_step=False,
 ):
 residual = x
 if self.normalize_before:
@@ -154,7 +153,6 @@ class DecoderLayer(nn.Module):
 incremental_state=incremental_state,
 attn_mask=self_attn_mask,
 rel_pos=self_attn_rel_pos,
-is_first_step=is_first_step,
 )
 x = self.dropout_module(x)
@@ -359,7 +357,7 @@ class Decoder(nn.Module):
 tokens, incremental_state=incremental_state
 )
-if incremental_state is not None and not self.is_first_step(incremental_state):
+if incremental_state is not None:
 tokens = tokens[:, -1:]
 if positions is not None:
 positions = positions[:, -1:]
@@ -379,11 +377,6 @@ class Decoder(nn.Module):
 return x, embed
-def is_first_step(self, incremental_state):
-if incremental_state is None:
-return False
-return incremental_state.get("is_first_step", False)
 def forward(
 self,
 prev_output_tokens,
@@ -399,7 +392,6 @@ class Decoder(nn.Module):
 x, _ = self.forward_embedding(
 prev_output_tokens, token_embeddings, incremental_state
 )
-is_first_step = self.is_first_step(incremental_state)
 # relative position
 self_attn_rel_pos_bias = None
@@ -408,7 +400,7 @@ class Decoder(nn.Module):
 self_attn_rel_pos_bias = self.self_attn_relative_position(
 batch_size=x.size(0), qlen=slen, klen=slen
 )
-if incremental_state is not None and not is_first_step:
+if incremental_state is not None:
 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
 cross_attn_rel_pos_bias = None
 if self.cross_attn_relative_position is not None:
@@ -417,7 +409,7 @@ class Decoder(nn.Module):
 qlen=slen,
 klen=encoder_out["encoder_out"].size(1),
 )
-if incremental_state is not None and not is_first_step:
+if incremental_state is not None:
 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
 # decoder layers
@@ -429,7 +421,7 @@ class Decoder(nn.Module):
 l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
 for idx, layer in enumerate(self.layers):
-if incremental_state is None or is_first_step:
+if incremental_state is None:
 self_attn_mask = torch.triu(
 torch.zeros([x.size(1), x.size(1)])
 .float()
@@ -437,9 +429,6 @@ class Decoder(nn.Module):
 .type_as(x),
 1,
 )
-if is_first_step and incremental_state is not None:
-if idx not in incremental_state:
-incremental_state[idx] = {}
 else:
 self_attn_mask = None
 if idx not in incremental_state:
@@ -456,7 +445,6 @@ class Decoder(nn.Module):
 self_attn_padding_mask=self_attn_padding_mask,
 self_attn_rel_pos=self_attn_rel_pos_bias,
 cross_attn_rel_pos=cross_attn_rel_pos_bias,
-is_first_step=is_first_step,
 )
 l_aux.append(l_aux_i)
 inner_states.append(x)
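The hunks above drop the `is_first_step` bookkeeping, so on the 0.2.0 side a non-`None` `incremental_state` always means "feed only the newest token and reuse the per-layer cache". A hedged sketch of that incremental-decoding pattern; the `Decoder` import path and forward signature are assumptions based on the classes and imports appearing elsewhere in this diff, not code taken from either branch.

```python
# Hedged sketch of incremental decoding, to show why the code above slices
# `tokens[:, -1:]` whenever `incremental_state` is not None: each step only the
# newest token is embedded, and per-layer caches live in incremental_state[idx].
import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder   # module path assumed

config = DecoderConfig(vocab_size=64000)
embed_tokens = torch.nn.Embedding(config.vocab_size, config.decoder_embed_dim)
decoder = Decoder(config, embed_tokens=embed_tokens)

tokens = torch.randint(0, config.vocab_size, (1, 1))  # a one-token prompt
incremental_state = {}                                 # filled per layer index
for _ in range(4):
    # assumes forward returns (logits, extra), as the RetNetDecoder later in this diff does
    logits, _ = decoder(tokens, incremental_state=incremental_state)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)
print(tokens)
```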

View File

@@ -1,403 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
from torchscale.component.gate_linear_unit import GLU
from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm
class RetNetRelPos(nn.Module):
def __init__(self, args):
super().__init__()
angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
self.register_buffer("angle", angle)
self.register_buffer("decay", decay)
self.recurrent_chunk_size = args.recurrent_chunk_size
def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
if activate_recurrent:
sin = torch.sin(self.angle * (slen - 1))
cos = torch.cos(self.angle * (slen - 1))
retention_rel_pos = ((sin, cos), self.decay.exp())
elif chunkwise_recurrent:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
mask = torch.tril(torch.ones(self.recurrent_chunk_size, self.recurrent_chunk_size).to(self.decay))
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
value_inner_decay = value_inner_decay.unsqueeze(-1)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
inner_mask = mask / scale
cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
cross_decay = cross_decay[:, None, None]
retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
else:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
retention_rel_pos = ((sin, cos), mask)
return retention_rel_pos
class DecoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
):
super().__init__()
self.args = args
self.embed_dim = args.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(args.dropout)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = self.build_retention(self.embed_dim, args)
self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
if not self.is_moe_layer:
self.ffn = self.build_ffn(
self.embed_dim,
self.args,
)
else:
if args.moe_top1_expert:
gate = Top1Gate(
self.embed_dim,
args.moe_expert_count,
use_fp32=args.moe_gating_use_fp32,
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
else:
gate = Top2Gate(
self.embed_dim,
args.moe_expert_count,
args.moe_gating_use_fp32,
args.moe_second_expert_policy,
args.moe_normalize_gate_prob_before_dropping,
args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
else:
self.alpha = 1.0
def build_ffn(self, embed_dim, args):
return GLU(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
) if args.use_glu else FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
def build_retention(self, embed_dim, args):
return MultiScaleRetention(
args,
embed_dim,
args.decoder_value_embed_dim,
args.decoder_retention_heads,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
if not self.is_moe_layer:
x = self.ffn(x)
l_aux = None
else:
x, l_aux = self.moe_layer(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, l_aux
class RetNetDecoder(nn.Module):
def __init__(
self,
args,
embed_tokens=None,
output_projection=None,
**kwargs
):
super().__init__(**kwargs)
self.args = args
self.dropout_module = torch.nn.Dropout(args.dropout)
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
if (
output_projection is None
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
moe_freq = args.moe_freq
for i in range(args.decoder_layers):
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
self.layers.append(
self.build_decoder_layer(
args,
depth=i,
is_moe_layer=is_moe_layer,
)
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(args)
self.chunkwise_recurrent = args.chunkwise_recurrent
self.recurrent_chunk_size = args.recurrent_chunk_size
if args.deepnorm:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
def build_output_projection(
self,
args,
):
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
return output_projection
def build_decoder_layer(
self, args, depth, is_moe_layer=False
):
layer = DecoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
if args.fsdp:
layer = wrap(layer)
return layer
def forward_embedding(
self,
tokens,
token_embedding=None,
incremental_state=None,
):
if incremental_state is not None and not self.is_first_step(incremental_state):
tokens = tokens[:, -1:]
if token_embedding is None:
token_embedding = self.embed_tokens(tokens)
x = embed = self.embed_scale * token_embedding
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def is_first_step(self, incremental_state):
if incremental_state is None:
return False
return incremental_state.get("is_first_step", False)
def forward(
self,
prev_output_tokens,
incremental_state=None,
features_only=False,
return_all_hiddens=False,
token_embeddings=None,
**kwargs
):
# embed tokens
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
is_first_step = self.is_first_step(incremental_state)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
slen = prev_output_tokens.size(1) + padding_len
x = F.pad(x, (0, 0, 0, padding_len))
else:
slen = prev_output_tokens.size(1)
# relative position
retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
# decoder layers
inner_states = [x]
l_aux = []
for idx, layer in enumerate(self.layers):
if incremental_state is None or is_first_step:
if is_first_step and incremental_state is not None:
if idx not in incremental_state:
incremental_state[idx] = {}
else:
if idx not in incremental_state:
incremental_state[idx] = {}
x, l_aux_i = layer(
x,
incremental_state[idx] if incremental_state is not None else None,
retention_rel_pos=retention_rel_pos,
chunkwise_recurrent=self.chunkwise_recurrent,
)
l_aux.append(l_aux_i)
inner_states.append(x)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
x = x[:, :prev_output_tokens.size(1), :]
if self.layer_norm is not None:
x = self.layer_norm(x)
if not features_only:
x = self.output_layer(x)
return x, {
"inner_states": inner_states,
"l_aux": l_aux,
"attn": None,
}
def output_layer(self, features):
return self.output_projection(features)
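`RetNetRelPos` above produces a different second element depending on which execution path `MultiScaleRetention` takes (parallel, recurrent with `incremental_state`, or chunkwise). A small sketch of the three modes, assuming the main-branch `torchscale.architecture.retnet` module shown in this section:

```python
# The three relative-position/decay forms computed by RetNetRelPos above,
# matching the parallel / recurrent / chunkwise paths of MultiScaleRetention.
import torch
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetRelPos  # main-branch module, as shown above

config = RetNetConfig(vocab_size=64000, recurrent_chunk_size=4)
rel_pos = RetNetRelPos(config)

(sin, cos), decay_mask = rel_pos(8)                         # parallel: per-head (heads, 8, 8) decay mask
(_, _), decay = rel_pos(8, activate_recurrent=True)         # recurrent: one decay scalar per head
(_, _), chunk_parts = rel_pos(8, chunkwise_recurrent=True)  # chunkwise: (inner_mask, cross, query, value decay)
print(decay_mask.shape, decay.shape, len(chunk_parts))
```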

View File

@@ -10,9 +10,6 @@ except ModuleNotFoundError:
 from torch.nn import LayerNorm
-from .xmoe.global_groups import get_moe_group
 class set_torch_seed(object):
 def __init__(self, seed):
 assert isinstance(seed, int)
@@ -73,9 +70,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
 world_size % args.moe_expert_count == 0
 ), f"{world_size}, {args.moe_expert_count}"
-moe_idx, _ = get_moe_group(args.moe_expert_count)
-with set_torch_seed(start_seed + moe_idx):
+with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
 expert_list.append(
 FeedForwardNetwork(
 embed_dim,
@@ -96,8 +91,6 @@ def get_activation_fn(activation):
 return F.relu
 elif activation == "gelu":
 return F.gelu
-elif activation == "swish":
-return F.silu
 else:
 raise NotImplementedError

View File

@@ -1,44 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
import torch.nn.functional as F
from .feedforward_network import get_activation_fn
class GLU(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
self.gate.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
g = self.gate(x)
x = self.fc1(x)
x = self.activation_fn(x.float()).type_as(x) * g
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x
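The `GLU` above is the SwiGLU-style gated FFN that the removed "October, 2023" news entry in the README hunk refers to: `output = fc2(act(fc1(x)) * gate(x))`. A tiny usage sketch, assuming the main-branch `torchscale.component.gate_linear_unit` path imported elsewhere in this diff:

```python
# Tiny usage sketch of the main-branch GLU above; with activation_fn="swish"
# this is the SwiGLU variant (gate multiplied into the activated projection).
import torch
from torchscale.component.gate_linear_unit import GLU  # path as imported in the RetNet decoder above

glu = GLU(embed_dim=768, ffn_dim=1280, activation_fn="swish",
          dropout=0.0, activation_dropout=0.0)
x = torch.randn(2, 16, 768)   # (batch, sequence, embed_dim)
y = glu(x)
print(y.shape)                # torch.Size([2, 16, 768])
```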

View File

@@ -71,7 +71,6 @@ class MultiheadAttention(nn.Module):
 key_padding_mask=None,
 attn_mask=None,
 rel_pos=None,
-is_first_step=False,
 ):
 bsz, tgt_len, embed_dim = query.size()
 src_len = tgt_len
@@ -113,7 +112,7 @@ class MultiheadAttention(nn.Module):
 src_len = k.size(1)
 if self.xpos is not None:
-if incremental_state is not None and not is_first_step:
+if incremental_state is not None:
 offset = src_len - 1
 else:
 offset = 0

View File

@@ -1,210 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn.functional as F
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def theta_shift(x, sin, cos):
return (x * cos) + (rotate_every_two(x) * sin)
def get_activation_fn(activation):
if activation == "swish":
return F.silu
elif activation == "gelu":
return F.gelu
else:
raise NotImplementedError
class MultiScaleRetention(nn.Module):
def __init__(
self,
args,
embed_dim,
value_dim,
num_heads,
gate_fn="swish",
):
super().__init__()
self.args = args
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))
self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight)
if hasattr(self.out_proj, "bias"):
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()
vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
qk_mat = qk_mat * mask
# invariant after normalization
qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
output = torch.matmul(qk_mat, vr)
output = output.transpose(1, 2)
return output
def recurrent_forward(
self,
qr, kr, v,
decay,
incremental_state
):
bsz = v.size(0)
v = v.view(bsz, self.num_heads, self.head_dim, 1)
kv = kr * v
if "prev_key_value" in incremental_state:
prev_kv = incremental_state["prev_key_value"]
prev_scale = incremental_state["scale"]
scale = prev_scale * decay + 1
kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
# kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
else:
scale = torch.ones_like(decay)
incremental_state["prev_key_value"] = kv
incremental_state["scale"] = scale
output = torch.sum(qr * kv, dim=3)
return output
def chunk_recurrent_forward(
self,
qr, kr, v,
inner_mask
):
mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
bsz, tgt_len, embed_dim = v.size()
chunk_len = mask.size(1)
num_chunks = tgt_len // chunk_len
assert tgt_len % chunk_len == 0
qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
kr_t = kr.transpose(-1, -2)
qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
qk_mat = qk_mat * mask
inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
qk_mat = qk_mat / inner_scale
inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
# reduce kv in one chunk
kv = kr_t @ (v * value_inner_decay)
kv_recurrent = []
cross_scale = []
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
# accumulate kv by loop
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
kv_recurrent = torch.stack(kv_recurrent, dim=1)
cross_scale = torch.stack(cross_scale, dim=1)
all_scale = torch.maximum(inner_scale, cross_scale)
align_inner_scale = all_scale / inner_scale
align_cross_scale = all_scale / cross_scale
cross_output = (qr * query_inner_decay) @ kv_recurrent
output = inner_output / align_inner_scale + cross_output / align_cross_scale
# output = inner_output / cross_scale + cross_output / inner_scale
output = output.transpose(2, 3)
return output
def forward(
self,
x,
rel_pos,
chunkwise_recurrent=False,
incremental_state=None
):
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
g = self.g_proj(x)
k *= self.scaling
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
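The `rotate_every_two`/`theta_shift` pair above is the usual rotary-embedding rotation applied to interleaved (even, odd) feature pairs. A self-contained numerical check of that reading, with the helpers restated for a 1-D tensor:

```python
# theta_shift rotates each (even, odd) feature pair by its angle, i.e. a
# standard rotary embedding; verify against an explicit 2x2 rotation matrix.
import torch

def rotate_every_two(x):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

theta = torch.rand(2)                      # one angle per feature pair
sin = theta.sin().repeat_interleave(2)     # duplicated per pair, as in RetNetRelPos's angle
cos = theta.cos().repeat_interleave(2)
x = torch.randn(4)
y = theta_shift(x, sin, cos)

rot = torch.stack((torch.stack((cos[0], -sin[0])),
                   torch.stack((sin[0],  cos[0]))))   # 2x2 rotation by theta[0]
assert torch.allclose(y[:2], rot @ x[:2])
```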

View File

@@ -1,25 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output
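A quick check of what the `RMSNorm` above computes: the input is rescaled by the inverse root-mean-square of its last dimension and, unlike LayerNorm, the mean is not subtracted. The import path is the main-branch `torchscale.component.rms_norm` referenced in the RetNet decoder above.

```python
# RMSNorm with the default weight of ones matches the manual RMS rescaling.
import torch
from torchscale.component.rms_norm import RMSNorm  # main-branch path, as imported above

x = torch.randn(2, 5, 16)
norm = RMSNorm(16)
manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
assert torch.allclose(norm(x), manual, atol=1e-6)
```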

View File

@@ -1,65 +0,0 @@
import torch.distributed as dist
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count=None):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_expert_count = moe_expert_count
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return my_group_idx, get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
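The grouping arithmetic above is easier to see with concrete numbers. A small sketch evaluating the same comprehensions for `world_size=4` and `moe_expert_count=2`; no `torch.distributed` is needed for the arithmetic itself.

```python
# The same group layout computed by get_moe_group / get_all2all_group above,
# evaluated by hand for world_size=4 and moe_expert_count=2.
world_size, moe_expert_count = 4, 2
ranks_per_group = world_size // moe_expert_count

moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
              for i in range(moe_expert_count)]
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
                  for i in range(ranks_per_group)]

print(moe_groups)      # [[0, 2], [1, 3]]  -> ranks sharing the same expert shard
print(all2all_groups)  # [[0, 1], [2, 3]]  -> ranks exchanging tokens all-to-all
```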

View File

@@ -18,8 +18,6 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.nn import Module, ModuleList
-from .global_groups import get_all2all_group, get_moe_group
 try:
 from fairseq.modules.moe import MOELayer
@@ -63,6 +61,64 @@ class _AllToAll(torch.autograd.Function):
 return (None, _AllToAll.apply(ctx.group, *grad_output))
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
 class MOELayer(Base):
@@ -93,7 +149,7 @@ class MOELayer(Base):
 self.experts = cast(ModuleList, experts)
 else:
 self.experts = ModuleList([experts])
-_, self.expert_group = get_moe_group(args.moe_expert_count)
+self.expert_group = get_moe_group(args.moe_expert_count)
 self.all2all_group = get_all2all_group(args.moe_expert_count)
 self.world_size = dist.get_world_size(group=self.expert_group)
 self.all2all_size = dist.get_world_size(group=self.all2all_group)