Compare commits
No commits in common. "main" and "0.2.0" have entirely different histories.
.github/workflows/test.yml (vendored): 26 changes
@@ -1,26 +0,0 @@
-name: Python package
-
-on: [push, pull_request]
-
-jobs:
-  build:
-
-    runs-on: ubuntu-latest
-
-    steps:
-    - uses: actions/checkout@v2
-    - name: Set up Python 3.10
-      uses: actions/setup-python@v2
-      with:
-        python-version: "3.10"
-    - name: Install dependencies
-      run: |
-        python -m pip install --upgrade pip
-        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-        if [ -f setup.py ]; then pip install .; fi
-    - name: Install pytest
-      run: |
-        pip install pytest
-    - name: Run tests
-      run: |
-        pytest tests/

README.md: 52 changes
@@ -1,4 +1,4 @@
-# TorchScale - A Library of Foundation Architectures
+# TorchScale - A Library for Transformers at (Any) Scale
 
 <p>
   <a href="https://github.com/microsoft/torchscale/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
@@ -6,20 +6,15 @@
 </p>
 
 TorchScale is a PyTorch library that allows researchers and developers to scale up Transformers efficiently and effectively.
+It has the implementation of fundamental research to improve modeling generality and capability as well as training stability and efficiency of scaling Transformers.
 
-Fundamental research to develop new architectures for foundation models and A(G)I, focusing on modeling generality and capability, as well as training stability and efficiency.
-
 - Stability - [**DeepNet**](https://arxiv.org/abs/2203.00555): scaling Transformers to 1,000 Layers and beyond
 - Generality - [**Foundation Transformers (Magneto)**](https://arxiv.org/abs/2210.06423): towards true general-purpose modeling across tasks and modalities (including language, vision, speech, and multimodal)
 - Capability - A [**Length-Extrapolatable**](https://arxiv.org/abs/2212.10554) Transformer
 - Efficiency - [**X-MoE**](https://arxiv.org/abs/2204.09179): scalable & finetunable sparse Mixture-of-Experts (MoE)
 
-### Revolutionizing Transformers for (M)LLMs and AI
-
-- [**RetNet**](https://arxiv.org/abs/2307.08621): Retentive Network: A Successor to Transformer for Large Language Models
-- [**LongNet**](https://arxiv.org/abs/2307.02486): Scaling Transformers to 1,000,000,000 Tokens
-
 ## News
 
-- October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
 - November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]
 
 ## Installation
@@ -70,20 +65,6 @@ We also support the `Decoder` architecture and the `EncoderDecoder` architecture
 >>> print(encdec)
 ```
 
-It takes only several lines of code to create a RetNet model:
-
-```python
-# Creating a RetNet model
->>> import torch
->>> from torchscale.architecture.config import RetNetConfig
->>> from torchscale.architecture.retnet import RetNetDecoder
-
->>> config = RetNetConfig(vocab_size=64000)
->>> retnet = RetNetDecoder(config)
-
->>> print(retnet)
-```
-
 ## Key Features
 
 - [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
@@ -112,9 +93,6 @@ It takes only several lines of code to create a RetNet model:
 - [SparseClip: improving the gradient clipping for sparse MoE models](https://arxiv.org/abs/2211.13184)
   * we provide a [sample code](examples/fairseq/utils/sparse_clip.py) that can be easily adapted to the FairSeq (or other) repo.
 
-- [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621)
-  * created by `config = RetNetConfig(vocab_size=64000)` and `retnet = RetNetDecoder(config)`.
-
 Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
 
 ```python
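
The fenced Python example that this hunk's trailing context opens is cut off in this compare view. As a rough, hypothetical sketch of the pattern the sentence above describes (feature flags passed straight to the config object; the flag names below are taken from the `RetNetConfig` and README code shown elsewhere in this diff, not from the elided README example):

```python
# Hedged sketch: most features are enabled by passing flags to the config.
# Flag names come from RetNetConfig as shown later in this compare view.
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

config = RetNetConfig(
    vocab_size=64000,
    deepnorm=True,             # DeepNorm-style residual scaling for stability
    chunkwise_recurrent=True,  # chunk-recurrent forward for long sequences
    recurrent_chunk_size=512,
)
retnet = RetNetDecoder(config)
```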
@@ -129,7 +107,7 @@ Most of the features above can be used by simply passing the corresponding parameters to the config. For example:
 
 ## Examples
 
-We have examples of how to use TorchScale in the following scenarios/tasks:
+We have the examples of how to use TorchScale in the following scenarios/tasks:
 
 - Language
 
@@ -147,7 +125,7 @@ We have examples of how to use TorchScale in the following scenarios/tasks:
 - Multimodal
 
-  * [Multiway Transformers/BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3)
+  * [Multiway Transformers/BEiT-3](torchscale/model/BEiT3.py) [In progress]
 
 We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
 
@@ -156,7 +134,7 @@ We plan to provide more examples regarding different tasks (e.g. vision pretraining and speech recognition) and various deep learning toolkits (e.g. [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)). Any comments or PRs are welcome!
 ### Stability Evaluation
 
 <p align="center">
-  <img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
+  <img src="https://publicmodel.blob.core.windows.net/torchscale/pic/convergence.png" width="800"/>
 </p>
 
 The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
@@ -164,7 +142,7 @@ The training curve is smooth by using TorchScale, while the baseline Transformer cannot converge.
 ### Scaling-up Experiments
 
 <p align="center">
-  <img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D" width="800"/>
+  <img src="https://publicmodel.blob.core.windows.net/torchscale/pic/scaling_curve.png" width="800"/>
 </p>
 
 TorchScale supports arbitrary depths and widths, successfully scaling-up the models without pain.
@@ -217,16 +195,6 @@ If you find this repository useful, please consider citing our work:
 }
 ```
 
-```
-@article{retnet,
-  author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
-  title = {Retentive Network: A Successor to {Transformer} for Large Language Models},
-  journal = {ArXiv},
-  volume = {abs/2307.08621},
-  year = {2023}
-}
-```
-
 ## Contributing
 
 This project welcomes contributions and suggestions. Most contributions require you to agree to a
@@ -238,11 +206,13 @@ a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
 provided by the bot. You will only need to do this once across all repos using our CLA.
 
 This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
-For more information, see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
 contact [Furu Wei](mailto:fuwei@microsoft.com) and [Shuming Ma](mailto:shumma@microsoft.com) with any additional questions or comments.
 
 ## Trademarks
 
-This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
+trademarks or logos is subject to and must follow
+[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
 Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
-Any use of third-party trademarks or logos is subject to those third-party's policies.
+Any use of third-party trademarks or logos are subject to those third-party's policies.

@@ -65,7 +65,7 @@ Also, the JSON file should be in the format like this:
 ]
 ```
 
-You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt?sv=2020-04-08&st=2023-08-11T03%3A09%3A09Z&se=2053-08-12T03%3A09%3A00Z&sr=c&sp=rl&sig=3b6nDda%2Fu0vD6E%2BhoTO%2BHfNSnSlUfgvXFV%2FCNKquWjE%3D). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
+You can quickly get started with our processed vocabulary files: [sentencepiece.bpe.model](https://publicmodel.blob.core.windows.net/torchscale/vocab/sentencepiece.bpe.model) and [dict.txt](https://publicmodel.blob.core.windows.net/torchscale/vocab/dict.txt). Note that this vocabulary is English-only with 64K tokens. To train a new `sentencepiece.bpe.model` on your own data, please refer to the [SentencePiece](https://github.com/google/sentencepiece) repo. With the sentecepiece model and the installed `sentencepiece` library, you can extract the `dict.txt` file from it by
 ```
 spm_export_vocab --model=sentencepiece.bpe.model | sed 's/\t/ /g' | tail -n +4 > dict.txt
 ```
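
The paragraph above defers training a new `sentencepiece.bpe.model` to the SentencePiece repo. As a minimal sketch that is not part of this diff (the corpus file name is a placeholder), the `sentencepiece` Python package can train an equivalent 64K BPE model:

```python
# Hedged sketch using the sentencepiece Python API; "corpus.txt" is a placeholder
# for your own training text (one sentence per line).
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="corpus.txt",
    model_prefix="sentencepiece.bpe",  # writes sentencepiece.bpe.model / .vocab
    vocab_size=64000,                  # matches the 64K vocabulary mentioned above
    model_type="bpe",
)
```

The resulting `sentencepiece.bpe.model` can then be passed to the `spm_export_vocab` command shown above to produce `dict.txt`.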

@@ -22,7 +22,7 @@ from fairseq.models.transformer import Embedding
 from fairseq.modules import PositionalEmbedding
 from torch import Tensor
 
-from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
 from torchscale.architecture.encoder import Encoder
 
 from .language_modeling import LMDecoder as MTDecoder
@@ -308,7 +308,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
 
     @classmethod
     def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
-        config = EncoderDecoderConfig()
+        config = EncoderConfig()
         config.override(args)
 
         return MTEncoder(
@@ -323,7 +323,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
     def build_decoder(
         cls, args, embed_tokens, embed_positions, output_projection, dictionary
     ):
-        config = EncoderDecoderConfig()
+        config = DecoderConfig()
         config.override(args)
 
         return MTDecoder(

@@ -1,387 +0,0 @@
-# Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-from dataclasses import dataclass, field
-from typing import Optional
-
-import torch
-from fairseq import distributed_utils, utils
-from fairseq.dataclass import ChoiceEnum, FairseqDataclass
-from fairseq.models import (
-    FairseqIncrementalDecoder,
-    FairseqLanguageModel,
-    register_model,
-    register_model_architecture,
-)
-from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
-from omegaconf import II
-
-from torchscale.architecture.config import RetNetConfig
-from torchscale.architecture.retnet import RetNetDecoder
-
-DEFAULT_MAX_TARGET_POSITIONS = 1024
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class LanguageConfig(FairseqDataclass):
-    activation_fn: str = field(
-        default="swish", metadata={"help": "activation function to use"}
-    )
-    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
-    activation_dropout: float = field(
-        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
-    )
-    relu_dropout: float = field(
-        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
-    )
-    decoder_embed_dim: int = field(
-        default=512, metadata={"help": "decoder embedding dimension"}
-    )
-    decoder_value_embed_dim: int = field(
-        default=864, metadata={"help": "decoder embedding dimension"}
-    )
-    decoder_output_dim: int = field(
-        default=512, metadata={"help": "decoder output dimension"}
-    )
-    decoder_input_dim: int = field(
-        default=512, metadata={"help": "decoder input dimension"}
-    )
-    decoder_ffn_embed_dim: int = field(
-        default=864, metadata={"help": "decoder embedding dimension for FFN"}
-    )
-    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
-    decoder_retention_heads: int = field(
-        default=2, metadata={"help": "num decoder retention heads"}
-    )
-    decoder_normalize_before: bool = field(
-        default=False, metadata={"help": "apply norm before each decoder block"}
-    )
-    share_decoder_input_output_embed: bool = field(
-        default=False, metadata={"help": "share decoder input and output embeddings"}
-    )
-    decoder_learned_pos: bool = field(
-        default=False,
-        metadata={"help": "use learned positional embeddings in the decoder"},
-    )
-    layernorm_embedding: bool = field(
-        default=False, metadata={"help": "add norm to embedding"}
-    )
-    no_scale_embedding: bool = field(
-        default=False, metadata={"help": "if True, dont scale embeddings"}
-    )
-    checkpoint_activations: bool = field(
-        default=False, metadata={"help": "checkpoint activations at each layer"}
-    )
-    offload_activations: bool = field(
-        default=False,
-        metadata={"help": "move checkpointed activations to CPU after they are used."},
-    )
-    # config for Fully Sharded Data Parallel (FSDP) training
-    min_params_to_wrap: int = field(
-        default=DEFAULT_MIN_PARAMS_TO_WRAP,
-        metadata={
-            "help": (
-                "minimum number of params for a layer to be wrapped with FSDP() when "
-                "training with --ddp-backend=fully_sharded. Smaller values will "
-                "improve memory efficiency, but may make torch.distributed "
-                "communication less efficient due to smaller input sizes. This option "
-                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
-                "--offload-activations are passed."
-            )
-        },
-    )
-    moe_freq: int = field(
-        default=0,
-        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
-    )
-    moe_expert_count: int = field(
-        default=0, metadata={"help": "Number of experts in each MoE Layer"}
-    )
-    moe_gating_use_fp32: bool = field(
-        default=False,
-        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
-    )
-    moe_second_expert_policy: str = field(
-        default="sampling",
-        metadata={"help": "policy for second expert, options: all/sampling/random"},
-    )
-    moe_normalize_gate_prob_before_dropping: bool = field(
-        default=False,
-        metadata={
-            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
-        },
-    )
-    moe_expert_ffn_dim: Optional[int] = field(
-        default=None, metadata={"help": "MoE expert FFN dimension"}
-    )
-    moe_top1_expert: Optional[bool] = field(
-        default=False, metadata={"help": "Use top1 gate instead of top2"}
-    )
-    moe_eval_capacity_token_fraction: Optional[float] = field(
-        default=0.25,
-        metadata={
-            "help": (
-                "Default: 0.25, Fraction of tokens as capacity during validation, "
-                "if set to negative, use same as training. range: (0.0, 1.0]."
-            )
-        },
-    )
-    moe_normalize_expert_grad: Optional[str] = field(
-        default="world_size",
-        metadata={
-            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
-        },
-    )
-    record_a2a_perf_stats: Optional[bool] = field(
-        default=False,
-        metadata={"help": "records all to all perf stats during distributed training"},
-    )
-    dummy_a2a: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "By passes all to all during distributed training by returning the input buffer as output"
-        },
-    )
-    moe_batch_prioritized_routing: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "if true orders token by the gate prob before capacity dropping."
-        },
-    )
-    use_xmoe: Optional[bool] = field(
-        default=False,
-    )
-    chunkwise_recurrent: Optional[bool] = field(
-        default=False,
-    )
-    recurrent_chunk_size: Optional[int] = field(
-        default=512,
-    )
-
-    # options from other parts of the config
-    add_bos_token: bool = II("task.add_bos_token")
-    tokens_per_sample: int = II("task.tokens_per_sample")
-    max_target_positions: Optional[int] = II("task.max_target_positions")
-    tpu: bool = II("common.tpu")
-    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
-    fp16: bool = II("common.fp16")
-    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
-    ddp_backend: str = II("distributed_training.ddp_backend")
-    world_size: int = II("distributed_training.distributed_world_size")
-    distributed_rank: int = II("distributed_training.distributed_rank")
-    ddp_rank: int = II("distributed_training.distributed_rank")
-    deepnorm: Optional[bool] = field(
-        default=False,
-    )
-    subln: Optional[bool] = field(
-        default=False,
-    )
-
-
-@register_model("retnet", dataclass=LanguageConfig)
-class RetNetLanguageModel(FairseqLanguageModel):
-    def __init__(self, args, decoder):
-        self.args = args
-        super().__init__(decoder)
-
-    @classmethod
-    def build_model(cls, args, task):
-
-        if getattr(args, "max_target_positions", None) is None:
-            args.max_target_positions = getattr(
-                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
-            )
-
-        embed_tokens = cls.build_embedding(
-            args, task.source_dictionary, args.decoder_embed_dim
-        )
-        if args.share_decoder_input_output_embed:
-            output_projection = torch.nn.Linear(
-                embed_tokens.weight.shape[1],
-                embed_tokens.weight.shape[0],
-                bias=False,
-            )
-            output_projection.weight = embed_tokens.weight
-        else:
-            output_projection = torch.nn.Linear(
-                args.decoder_embed_dim, len(task.dictionary), bias=False
-            )
-            torch.nn.init.normal_(
-                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
-            )
-
-        if getattr(args, "moe_freq", 0) > 0 and (
-            getattr(args, "fp16", False)
-            and not getattr(args, "memory_efficient_fp16", False)
-            and getattr(args, "ddp_backend", None) != "fully_sharded"
-        ):
-            assert (
-                args.fp16_no_flatten_grads
-            ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
-
-        args.ddp_rank = distributed_utils.get_data_parallel_rank()
-
-        config = RetNetConfig()
-        config.override(args)
-
-        decoder = LMDecoder(
-            config,
-            embed_tokens,
-            output_projection,
-            dictionary=task.dictionary,
-        )
-
-        return cls(args, decoder)
-
-    @classmethod
-    def build_embedding(cls, args, dictionary, embed_dim, path=None):
-        return Embedding(len(dictionary), embed_dim, dictionary.pad())
-
-
-class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
-    def forward(self, src_tokens, **kwargs):
-        return super().forward(src_tokens, **kwargs)
-
-    def max_positions(self):
-        return self.args.max_target_positions
-
-    def reorder_incremental_state_scripting(
-        self,
-        incremental_state,
-        new_order,
-    ):
-        for module in incremental_state:
-            for key in incremental_state[module]:
-                result = incremental_state[module][key].index_select(0, new_order)
-                incremental_state[module][key] = result
-
-
-@register_model_architecture("retnet", "retnet_base")
-def retnet_base_architecture(args):
-    # backward compatibility for older model checkpoints
-    if hasattr(args, "no_tie_adaptive_proj"):
-        # previous models defined --no-tie-adaptive-proj, so use the existence of
-        # that option to determine if this is an "old" model checkpoint
-        args.no_decoder_final_norm = True  # old models always set this to True
-        if args.no_tie_adaptive_proj is False:
-            args.tie_adaptive_proj = True
-    if hasattr(args, "decoder_final_norm"):
-        args.no_decoder_final_norm = not args.decoder_final_norm
-
-    args.dropout = getattr(args, "dropout", 0.0)
-
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
-    args.decoder_layers = getattr(args, "decoder_layers", 6)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
-    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
-    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
-    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
-    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
-    args.activation_fn = getattr(args, "activation_fn", "swish")
-
-    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
-    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
-
-    args.base_layers = getattr(args, "base_layers", 0)
-    args.base_sublayers = getattr(args, "base_sublayers", 1)
-    args.base_shuffle = getattr(args, "base_shuffle", False)
-
-    args.add_bos_token = getattr(args, "add_bos_token", False)
-    args.no_token_positional_embeddings = getattr(
-        args, "no_token_positional_embeddings", False
-    )
-    args.share_decoder_input_output_embed = getattr(
-        args, "share_decoder_input_output_embed", False
-    )
-    args.character_embeddings = getattr(args, "character_embeddings", False)
-
-    args.decoder_output_dim = getattr(
-        args, "decoder_output_dim", args.decoder_embed_dim
-    )
-    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
-
-    args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
-    args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)
-
-    # Model training is not stable without this
-    args.decoder_normalize_before = True
-    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
-
-    args.adaptive_input = getattr(args, "adaptive_input", False)
-    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
-    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
-
-    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
-    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
-
-    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
-    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
-    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
-    args.offload_activations = getattr(args, "offload_activations", False)
-    if args.offload_activations:
-        args.checkpoint_activations = True
-
-
-@register_model_architecture("retnet", "retnet_medium")
-def retnet_medium(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
-    args.decoder_layers = getattr(args, "decoder_layers", 16)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
-    retnet_base_architecture(args)
-
-
-@register_model_architecture("retnet", "retnet_xl")
-def retnet_xl(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 8)
-    args.decoder_layers = getattr(args, "decoder_layers", 24)
-    retnet_base_architecture(args)
-
-
-@register_model_architecture("retnet", "retnet_3b")
-def retnet_3b(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 10)
-    args.decoder_layers = getattr(args, "decoder_layers", 32)
-    retnet_base_architecture(args)
-
-
-@register_model_architecture("retnet", "retnet_7b")
-def retnet_7b(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 16)
-    args.decoder_layers = getattr(args, "decoder_layers", 32)
-    retnet_base_architecture(args)
-
-
-@register_model_architecture("retnet", "retnet_13b")
-def retnet_13b(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 20)
-    args.decoder_layers = getattr(args, "decoder_layers", 40)
-    retnet_base_architecture(args)
-
-
-@register_model_architecture("retnet", "retnet_65b")
-def retnet_65b(args):
-    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
-    args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
-    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
-    args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 32)
-    args.decoder_layers = getattr(args, "decoder_layers", 64)
-    retnet_base_architecture(args)

setup.py: 2 changes

@@ -17,7 +17,7 @@ setup(
     license="MIT",
     url="https://github.com/microsoft/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.6.13"],
+    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
     python_requires=">=3.8.0",
     classifiers=[
         "Programming Language :: Python :: 3",

@@ -142,7 +142,6 @@ class EncoderDecoderConfig(object):
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
-        self.normalize_output = kwargs.pop("normalize_output", True)
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
         self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
         self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
@@ -207,77 +206,3 @@ class EncoderDecoderConfig(object):
         for hp in self.__dict__.keys():
             if getattr(args, hp, None) is not None:
                 self.__dict__[hp] = getattr(args, hp, None)
-
-
-class RetNetConfig(object):
-    def __init__(self, **kwargs):
-        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
-        self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
-        self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
-        self.decoder_layers = kwargs.pop("decoder_layers", 12)
-        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
-        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.multiway = kwargs.pop("multiway", False)
-        self.share_decoder_input_output_embed = kwargs.pop(
-            "share_decoder_input_output_embed", False
-        )
-        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
-        # Blockwise
-        self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
-        self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
-
-        # RetNet's RelPos base
-        self.rotary_embedding_base = kwargs.pop("rotary_embedding_base", 10000)
-
-        # Backwards compatibility flags
-        self.use_layernorm = kwargs.pop("use_layernorm", False)
-        self.use_biases = kwargs.pop("use_biases", False)
-        self.use_glu = kwargs.pop("use_glu", True)
-
-        if self.deepnorm:
-            self.decoder_normalize_before = False
-            self.subln = False
-        if self.subln:
-            self.decoder_normalize_before = True
-            self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
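
The removed `RetNetConfig.__init__` above ends by reconciling mutually exclusive flags: `deepnorm` forces post-LayerNorm and switches `subln` off, the (default) `subln` does the opposite, and `use_xmoe` asserts that MoE layers are actually enabled. A hedged sketch of the observable effect, valid only on the main side of this comparison where the class still exists:

```python
# Hedged sketch of the flag reconciliation in RetNetConfig.__init__ shown above.
from torchscale.architecture.config import RetNetConfig

cfg = RetNetConfig(vocab_size=64000, deepnorm=True)
# deepnorm wins: post-LN residuals, sub-LayerNorm disabled.
assert cfg.decoder_normalize_before is False and cfg.subln is False

cfg = RetNetConfig(vocab_size=64000)  # subln defaults to True
# subln keeps pre-LN and makes sure deepnorm stays off.
assert cfg.decoder_normalize_before is True and cfg.deepnorm is False
```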

@@ -140,7 +140,6 @@ class DecoderLayer(nn.Module):
         self_attn_padding_mask=None,
         self_attn_rel_pos=None,
         cross_attn_rel_pos=None,
-        is_first_step=False,
     ):
         residual = x
         if self.normalize_before:
@@ -154,7 +153,6 @@ class DecoderLayer(nn.Module):
             incremental_state=incremental_state,
             attn_mask=self_attn_mask,
             rel_pos=self_attn_rel_pos,
-            is_first_step=is_first_step,
         )
         x = self.dropout_module(x)
 
@@ -359,7 +357,7 @@ class Decoder(nn.Module):
                 tokens, incremental_state=incremental_state
             )
 
-        if incremental_state is not None and not self.is_first_step(incremental_state):
+        if incremental_state is not None:
             tokens = tokens[:, -1:]
             if positions is not None:
                 positions = positions[:, -1:]
@@ -379,11 +377,6 @@ class Decoder(nn.Module):
 
         return x, embed
 
-    def is_first_step(self, incremental_state):
-        if incremental_state is None:
-            return False
-        return incremental_state.get("is_first_step", False)
-
     def forward(
         self,
         prev_output_tokens,
@@ -399,7 +392,6 @@ class Decoder(nn.Module):
         x, _ = self.forward_embedding(
             prev_output_tokens, token_embeddings, incremental_state
         )
-        is_first_step = self.is_first_step(incremental_state)
 
         # relative position
         self_attn_rel_pos_bias = None
@@ -408,7 +400,7 @@ class Decoder(nn.Module):
             self_attn_rel_pos_bias = self.self_attn_relative_position(
                 batch_size=x.size(0), qlen=slen, klen=slen
             )
-            if incremental_state is not None and not is_first_step:
+            if incremental_state is not None:
                 self_attn_rel_pos_bias = self_attn_rel_pos_bias[-1:, :, :]
         cross_attn_rel_pos_bias = None
         if self.cross_attn_relative_position is not None:
@@ -417,7 +409,7 @@ class Decoder(nn.Module):
                 qlen=slen,
                 klen=encoder_out["encoder_out"].size(1),
             )
-            if incremental_state is not None and not is_first_step:
+            if incremental_state is not None:
                 cross_attn_rel_pos_bias = cross_attn_rel_pos_bias[-1:, :, :]
 
         # decoder layers
@@ -429,7 +421,7 @@ class Decoder(nn.Module):
             l_aux = encoder_out["l_aux"] if "l_aux" in encoder_out else []
 
         for idx, layer in enumerate(self.layers):
-            if incremental_state is None or is_first_step:
+            if incremental_state is None:
                 self_attn_mask = torch.triu(
                     torch.zeros([x.size(1), x.size(1)])
                     .float()
@@ -437,9 +429,6 @@ class Decoder(nn.Module):
                     .type_as(x),
                     1,
                 )
-                if is_first_step and incremental_state is not None:
-                    if idx not in incremental_state:
-                        incremental_state[idx] = {}
             else:
                 self_attn_mask = None
                 if idx not in incremental_state:
@@ -456,7 +445,6 @@ class Decoder(nn.Module):
                 self_attn_padding_mask=self_attn_padding_mask,
                 self_attn_rel_pos=self_attn_rel_pos_bias,
                 cross_attn_rel_pos=cross_attn_rel_pos_bias,
-                is_first_step=is_first_step,
             )
             l_aux.append(l_aux_i)
             inner_states.append(x)
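
The hunks above strip out the `is_first_step` bookkeeping that the main branch threads through incremental decoding: on the first incremental call the whole prefix is processed and the per-layer state dicts are created, while every later call feeds only the newest token. A hedged, stand-alone sketch of that control flow (names follow the diff; this is not the library's public API):

```python
# Hedged sketch of the first-step logic main adds in decoder.py (removed above).
def is_first_step(incremental_state):
    if incremental_state is None:
        return False
    return incremental_state.get("is_first_step", False)

def trim_tokens(tokens, incremental_state):
    # First incremental call: keep the full prefix so the cache can be built.
    # Later calls: only the newest token is needed.
    if incremental_state is not None and not is_first_step(incremental_state):
        return tokens[:, -1:]
    return tokens
```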
@ -1,403 +0,0 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from fairscale.nn import checkpoint_wrapper, wrap
|
|
||||||
|
|
||||||
from torchscale.architecture.utils import init_bert_params
|
|
||||||
from torchscale.component.droppath import DropPath
|
|
||||||
from torchscale.component.feedforward_network import make_experts, FeedForwardNetwork
|
|
||||||
from torchscale.component.gate_linear_unit import GLU
|
|
||||||
from torchscale.component.multiscale_retention import MultiScaleRetention
|
|
||||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
|
||||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
|
||||||
try:
|
|
||||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
from torch.nn import LayerNorm
|
|
||||||
from torchscale.component.rms_norm import RMSNorm
|
|
||||||
|
|
||||||
|
|
||||||
class RetNetRelPos(nn.Module):
|
|
||||||
def __init__(self, args):
|
|
||||||
super().__init__()
|
|
||||||
angle = 1.0 / (args.rotary_embedding_base ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
|
|
||||||
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
|
|
||||||
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
|
|
||||||
self.register_buffer("angle", angle)
|
|
||||||
self.register_buffer("decay", decay)
|
|
||||||
self.recurrent_chunk_size = args.recurrent_chunk_size
|
|
||||||
|
|
||||||
def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
|
|
||||||
if activate_recurrent:
|
|
||||||
sin = torch.sin(self.angle * (slen - 1))
|
|
||||||
cos = torch.cos(self.angle * (slen - 1))
|
|
||||||
retention_rel_pos = ((sin, cos), self.decay.exp())
|
|
||||||
elif chunkwise_recurrent:
|
|
||||||
index = torch.arange(slen).to(self.decay)
|
|
||||||
sin = torch.sin(index[:, None] * self.angle[None, :])
|
|
||||||
cos = torch.cos(index[:, None] * self.angle[None, :])
|
|
||||||
|
|
||||||
block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
|
|
||||||
mask = torch.tril(torch.ones(self.recurrent_chunk_size, self.recurrent_chunk_size).to(self.decay))
|
|
||||||
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
|
|
||||||
mask = torch.exp(mask * self.decay[:, None, None])
|
|
||||||
mask = torch.nan_to_num(mask)
|
|
||||||
|
|
||||||
value_inner_decay = mask[:, -1] / mask[:, -1].sum(dim=-1, keepdim=True)
|
|
||||||
value_inner_decay = value_inner_decay.unsqueeze(-1)
|
|
||||||
scale = mask.sum(dim=-1, keepdim=True).sqrt()
|
|
||||||
inner_mask = mask / scale
|
|
||||||
|
|
||||||
cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
|
|
||||||
query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
|
|
||||||
query_inner_decay = query_inner_decay[:, :, None] / (scale / mask[:, -1].sum(dim=-1)[:, None, None])
|
|
||||||
cross_decay = cross_decay[:, None, None]
|
|
||||||
retention_rel_pos = ((sin, cos), (inner_mask, cross_decay, query_inner_decay, value_inner_decay))
|
|
||||||
else:
|
|
||||||
index = torch.arange(slen).to(self.decay)
|
|
||||||
sin = torch.sin(index[:, None] * self.angle[None, :])
|
|
||||||
cos = torch.cos(index[:, None] * self.angle[None, :])
|
|
||||||
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
|
|
||||||
mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
|
|
||||||
mask = torch.exp(mask * self.decay[:, None, None])
|
|
||||||
mask = torch.nan_to_num(mask)
|
|
||||||
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
|
|
||||||
retention_rel_pos = ((sin, cos), mask)
|
|
||||||
|
|
||||||
return retention_rel_pos
|
|
||||||
|
|
||||||
class DecoderLayer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
args,
|
|
||||||
depth,
|
|
||||||
is_moe_layer=False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.embed_dim = args.decoder_embed_dim
|
|
||||||
self.dropout_module = torch.nn.Dropout(args.dropout)
|
|
||||||
|
|
||||||
if args.drop_path_rate > 0:
|
|
||||||
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
|
|
||||||
depth
|
|
||||||
]
|
|
||||||
self.drop_path = DropPath(drop_path_prob)
|
|
||||||
else:
|
|
||||||
self.drop_path = None
|
|
||||||
|
|
||||||
self.retention = self.build_retention(self.embed_dim, args)
|
|
||||||
|
|
||||||
self.normalize_before = args.decoder_normalize_before
|
|
||||||
|
|
||||||
self.retention_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
|
|
||||||
|
|
||||||
self.is_moe_layer = is_moe_layer
|
|
||||||
self.ffn_dim = args.decoder_ffn_embed_dim
|
|
||||||
|
|
||||||
if not self.is_moe_layer:
|
|
||||||
self.ffn = self.build_ffn(
|
|
||||||
self.embed_dim,
|
|
||||||
self.args,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if args.moe_top1_expert:
|
|
||||||
gate = Top1Gate(
|
|
||||||
self.embed_dim,
|
|
||||||
args.moe_expert_count,
|
|
||||||
use_fp32=args.moe_gating_use_fp32,
|
|
||||||
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
|
|
||||||
use_xmoe=args.use_xmoe,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
gate = Top2Gate(
|
|
||||||
self.embed_dim,
|
|
||||||
args.moe_expert_count,
|
|
||||||
args.moe_gating_use_fp32,
|
|
||||||
args.moe_second_expert_policy,
|
|
||||||
args.moe_normalize_gate_prob_before_dropping,
|
|
||||||
args.moe_eval_capacity_token_fraction,
|
|
||||||
use_xmoe=args.use_xmoe,
|
|
||||||
)
|
|
||||||
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
|
||||||
self.moe_layer = MOELayer(gate, experts, args)
|
|
||||||
|
|
||||||
self.final_layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(self.embed_dim, eps=args.layernorm_eps)
|
|
||||||
|
|
||||||
if args.deepnorm:
|
|
||||||
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
|
|
||||||
else:
|
|
||||||
self.alpha = 1.0
|
|
||||||
|
|
||||||
def build_ffn(self, embed_dim, args):
|
|
||||||
return GLU(
|
|
||||||
embed_dim,
|
|
||||||
self.ffn_dim,
|
|
||||||
args.activation_fn,
|
|
||||||
args.dropout,
|
|
||||||
args.activation_dropout,
|
|
||||||
) if args.use_glu else FeedForwardNetwork(
|
|
||||||
embed_dim,
|
|
||||||
self.ffn_dim,
|
|
||||||
args.activation_fn,
|
|
||||||
args.dropout,
|
|
||||||
args.activation_dropout,
|
|
||||||
args.layernorm_eps,
|
|
||||||
args.subln,
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_retention(self, embed_dim, args):
|
|
||||||
return MultiScaleRetention(
|
|
||||||
args,
|
|
||||||
embed_dim,
|
|
||||||
args.decoder_value_embed_dim,
|
|
||||||
args.decoder_retention_heads,
|
|
||||||
)
|
|
||||||
|
|
||||||
def residual_connection(self, x, residual):
|
|
||||||
return residual * self.alpha + x
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
incremental_state=None,
|
|
||||||
chunkwise_recurrent=False,
|
|
||||||
retention_rel_pos=None,
|
|
||||||
):
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.retention_layer_norm(x)
|
|
||||||
|
|
||||||
x = self.retention(
|
|
||||||
x,
|
|
||||||
incremental_state=incremental_state,
|
|
||||||
rel_pos=retention_rel_pos,
|
|
||||||
chunkwise_recurrent=chunkwise_recurrent,
|
|
||||||
)
|
|
||||||
x = self.dropout_module(x)
|
|
||||||
|
|
||||||
if self.drop_path is not None:
|
|
||||||
x = self.drop_path(x)
|
|
||||||
|
|
||||||
x = self.residual_connection(x, residual)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.retention_layer_norm(x)
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.final_layer_norm(x)
|
|
||||||
if not self.is_moe_layer:
|
|
||||||
x = self.ffn(x)
|
|
||||||
l_aux = None
|
|
||||||
else:
|
|
||||||
x, l_aux = self.moe_layer(x)
|
|
||||||
|
|
||||||
if self.drop_path is not None:
|
|
||||||
x = self.drop_path(x)
|
|
||||||
|
|
||||||
x = self.residual_connection(x, residual)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.final_layer_norm(x)
|
|
||||||
|
|
||||||
return x, l_aux
|
|
||||||
|
|
||||||
|
|
||||||
class RetNetDecoder(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
args,
|
|
||||||
embed_tokens=None,
|
|
||||||
output_projection=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.args = args
|
|
||||||
|
|
||||||
        self.dropout_module = torch.nn.Dropout(args.dropout)

        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_tokens = embed_tokens

        if (
            output_projection is None
            and not args.no_output_layer
            and args.vocab_size > 0
        ):
            self.output_projection = self.build_output_projection(args)
        else:
            self.output_projection = output_projection

        if args.layernorm_embedding:
            self.layernorm_embedding = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([])

        moe_freq = args.moe_freq
        for i in range(args.decoder_layers):
            is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
            self.layers.append(
                self.build_decoder_layer(
                    args,
                    depth=i,
                    is_moe_layer=is_moe_layer,
                )
            )

        self.num_layers = len(self.layers)

        if args.decoder_normalize_before:
            self.layer_norm = (LayerNorm if args.use_layernorm else RMSNorm)(embed_dim, eps=args.layernorm_eps)
        else:
            self.layer_norm = None

        self.retnet_rel_pos = RetNetRelPos(args)
        self.chunkwise_recurrent = args.chunkwise_recurrent
        self.recurrent_chunk_size = args.recurrent_chunk_size

        if args.deepnorm:
            init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
            for name, p in self.named_parameters():
                if (
                    "fc1" in name
                    or "fc2" in name
                    or "out_proj" in name
                    or "v_proj" in name
                ):
                    p.data.div_(init_scale)

    def build_output_projection(
        self,
        args,
    ):
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = self.embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, args.vocab_size, bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )
        return output_projection

    def build_decoder_layer(
        self, args, depth, is_moe_layer=False
    ):
        layer = DecoderLayer(
            args,
            depth,
            is_moe_layer=is_moe_layer,
        )
        if args.checkpoint_activations:
            layer = checkpoint_wrapper(layer)
        if args.fsdp:
            layer = wrap(layer)
        return layer

    def forward_embedding(
        self,
        tokens,
        token_embedding=None,
        incremental_state=None,
    ):
        if incremental_state is not None and not self.is_first_step(incremental_state):
            tokens = tokens[:, -1:]

        if token_embedding is None:
            token_embedding = self.embed_tokens(tokens)

        x = embed = self.embed_scale * token_embedding

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        return x, embed

    def is_first_step(self, incremental_state):
        if incremental_state is None:
            return False
        return incremental_state.get("is_first_step", False)

    def forward(
        self,
        prev_output_tokens,
        incremental_state=None,
        features_only=False,
        return_all_hiddens=False,
        token_embeddings=None,
        **kwargs
    ):
        # embed tokens
        x, _ = self.forward_embedding(
            prev_output_tokens, token_embeddings, incremental_state
        )
        is_first_step = self.is_first_step(incremental_state)

        if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
            padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
            slen = prev_output_tokens.size(1) + padding_len
            x = F.pad(x, (0, 0, 0, padding_len))
        else:
            slen = prev_output_tokens.size(1)

        # relative position
        retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)

        # decoder layers
        inner_states = [x]
        l_aux = []

        for idx, layer in enumerate(self.layers):
            if incremental_state is None or is_first_step:
                if is_first_step and incremental_state is not None:
                    if idx not in incremental_state:
                        incremental_state[idx] = {}
            else:
                if idx not in incremental_state:
                    incremental_state[idx] = {}

            x, l_aux_i = layer(
                x,
                incremental_state[idx] if incremental_state is not None else None,
                retention_rel_pos=retention_rel_pos,
                chunkwise_recurrent=self.chunkwise_recurrent,
            )
            l_aux.append(l_aux_i)
            inner_states.append(x)

        if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
            x = x[:, :prev_output_tokens.size(1), :]

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if not features_only:
            x = self.output_layer(x)

        return x, {
            "inner_states": inner_states,
            "l_aux": l_aux,
            "attn": None,
        }

    def output_layer(self, features):
        return self.output_projection(features)
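A minimal usage sketch of this decoder, assuming the RetNetConfig / RetNetDecoder entry points advertised in the project README and their default hyperparameters; the embedding table, batch size, and sequence length here are illustrative choices, not values taken from the diff.

import torch
import torch.nn as nn
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder

# Default config except for the vocabulary size; embed_tokens is supplied by the caller.
config = RetNetConfig(vocab_size=64000)
embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim)
decoder = RetNetDecoder(config, embed_tokens=embed_tokens)

tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
logits, extra = decoder(tokens)                         # parallel (training) path
print(logits.shape)                                     # (2, 16, vocab_size)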
@@ -10,9 +10,6 @@ except ModuleNotFoundError:
     from torch.nn import LayerNorm
 
 
-from .xmoe.global_groups import get_moe_group
-
-
 class set_torch_seed(object):
     def __init__(self, seed):
         assert isinstance(seed, int)
@@ -73,9 +70,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
             world_size % args.moe_expert_count == 0
         ), f"{world_size}, {args.moe_expert_count}"
 
-        moe_idx, _ = get_moe_group(args.moe_expert_count)
-
-        with set_torch_seed(start_seed + moe_idx):
+        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
             expert_list.append(
                 FeedForwardNetwork(
                     embed_dim,
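The intent of the seeding in both versions is the same: every rank that hosts a replica of the same expert must construct it with the same seed so the replicated weights match. A standalone illustrative sketch of that idea (build_expert_with_seed is a hypothetical helper, not the library's set_torch_seed):

import torch
import torch.nn as nn

def build_expert_with_seed(seed: int, embed_dim: int = 8, ffn_dim: int = 16) -> nn.Module:
    rng_state = torch.random.get_rng_state()   # save the global RNG state
    torch.manual_seed(seed)                    # deterministic expert initialization
    expert = nn.Sequential(
        nn.Linear(embed_dim, ffn_dim),
        nn.ReLU(),
        nn.Linear(ffn_dim, embed_dim),
    )
    torch.random.set_rng_state(rng_state)      # leave the global RNG untouched
    return expert

# Two ranks in the same expert group pass the same seed and build identical weights.
a = build_expert_with_seed(1234)
b = build_expert_with_seed(1234)
assert all(torch.equal(p, q) for p, q in zip(a.state_dict().values(), b.state_dict().values()))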
@@ -96,8 +91,6 @@ def get_activation_fn(activation):
         return F.relu
     elif activation == "gelu":
         return F.gelu
-    elif activation == "swish":
-        return F.silu
     else:
         raise NotImplementedError
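The "swish" branch removed above maps to PyTorch's SiLU, i.e. x * sigmoid(x). A quick illustrative check:

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 2.0])
assert torch.allclose(F.silu(x), x * torch.sigmoid(x))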
@@ -1,44 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F

from .feedforward_network import get_activation_fn


class GLU(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        self.gate.reset_parameters()

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        g = self.gate(x)
        x = self.fc1(x)
        x = self.activation_fn(x.float()).type_as(x) * g
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x
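A minimal usage sketch of the gated FFN above, assuming GLU (and the get_activation_fn it imports) is in scope; the dimensions are illustrative. The forward pass computes activation(fc1(x)) * gate(x) and projects back to embed_dim with fc2.

import torch

glu = GLU(embed_dim=512, ffn_dim=2048, activation_fn="swish", dropout=0.0, activation_dropout=0.0)
x = torch.randn(2, 16, 512)   # (batch, seq_len, embed_dim)
y = glu(x)                    # swish(fc1(x)) * gate(x), then fc2 back to embed_dim
print(y.shape)                # torch.Size([2, 16, 512])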
@@ -71,7 +71,6 @@ class MultiheadAttention(nn.Module):
         key_padding_mask=None,
         attn_mask=None,
         rel_pos=None,
-        is_first_step=False,
     ):
         bsz, tgt_len, embed_dim = query.size()
         src_len = tgt_len
@@ -113,7 +112,7 @@ class MultiheadAttention(nn.Module):
         src_len = k.size(1)
 
         if self.xpos is not None:
-            if incremental_state is not None and not is_first_step:
+            if incremental_state is not None:
                 offset = src_len - 1
             else:
                 offset = 0
@@ -1,210 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper


def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def duplicate_interleave(m):
    """
    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
    """
    dim0 = m.shape[0]
    m = m.view(-1, 1)  # flatten the matrix
    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
    return m


def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)


def get_activation_fn(activation):
    if activation == "swish":
        return F.silu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError
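The helpers above implement a rotary-style phase shift over adjacent channel pairs, applied to queries and keys before retention. A tiny illustrative check, assuming rotate_every_two and theta_shift as defined above are in scope:

import torch

x = torch.arange(8.0).view(1, 1, 1, 8)   # (bsz, heads, seq, dim), dim split into (x1, x2) pairs
r = rotate_every_two(x)
print(r[0, 0, 0])                         # tensor([-1., 0., -3., 2., -5., 4., -7., 6.])

# With sin=0, cos=1 the shift is the identity; with sin=1, cos=0 it is a pure 90-degree rotation.
assert torch.equal(theta_shift(x, torch.zeros(8), torch.ones(8)), x)
assert torch.equal(theta_shift(x, torch.ones(8), torch.zeros(8)), r)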
class MultiScaleRetention(nn.Module):
    def __init__(
        self,
        args,
        embed_dim,
        value_dim,
        num_heads,
        gate_fn="swish",
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.value_dim = value_dim
        self.num_heads = num_heads
        self.head_dim = self.value_dim // num_heads
        self.key_dim = self.embed_dim // num_heads
        self.scaling = self.key_dim ** -0.5

        self.gate_fn = get_activation_fn(activation=str(gate_fn))

        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=args.use_biases))
        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))
        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=args.use_biases))

        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=args.use_biases))

        self.group_norm = MultiwayWrapper(args, (LayerNorm if args.use_layernorm else RMSNorm)(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if hasattr(self.out_proj, "bias"):
            nn.init.constant_(self.out_proj.bias, 0.0)

    def parallel_forward(self, qr, kr, v, mask):
        bsz, tgt_len, embed_dim = v.size()

        vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        qk_mat = qr @ kr.transpose(-1, -2)  # bsz * m * tgt_len * tgt_len
        qk_mat = qk_mat * mask
        # invariant after normalization
        qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
        output = torch.matmul(qk_mat, vr)
        output = output.transpose(1, 2)
        return output

    def recurrent_forward(
        self,
        qr, kr, v,
        decay,
        incremental_state
    ):
        bsz = v.size(0)

        v = v.view(bsz, self.num_heads, self.head_dim, 1)
        kv = kr * v
        if "prev_key_value" in incremental_state:
            prev_kv = incremental_state["prev_key_value"]
            prev_scale = incremental_state["scale"]
            scale = prev_scale * decay + 1
            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
            # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
        else:
            scale = torch.ones_like(decay)

        incremental_state["prev_key_value"] = kv
        incremental_state["scale"] = scale

        output = torch.sum(qr * kv, dim=3)
        return output

    def chunk_recurrent_forward(
        self,
        qr, kr, v,
        inner_mask
    ):
        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.size()
        chunk_len = mask.size(1)
        num_chunks = tgt_len // chunk_len

        assert tgt_len % chunk_len == 0

        qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)

        kr_t = kr.transpose(-1, -2)

        qk_mat = qr @ kr_t  # bsz * num_heads * chunk_len * chunk_len
        qk_mat = qk_mat * mask
        inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = torch.matmul(qk_mat, v)  # bsz * num_heads * num_value_heads * chunk_len * head_dim

        # reduce kv in one chunk
        kv = kr_t @ (v * value_inner_decay)

        kv_recurrent = []
        cross_scale = []
        kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
        kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)

        # accumulate kv by loop
        for i in range(num_chunks):
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)
            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)

        kv_recurrent = torch.stack(kv_recurrent, dim=1)
        cross_scale = torch.stack(cross_scale, dim=1)

        all_scale = torch.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * query_inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale
        # output = inner_output / cross_scale + cross_output / inner_scale

        output = output.transpose(2, 3)
        return output

    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
        elif chunkwise_recurrent:
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

        output = self.gate_fn(g) * output

        output = self.out_proj(output)

        return output
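The class above exposes three routes through the same retention operation: parallel over the whole sequence, token-by-token recurrent, and chunkwise recurrent. A standalone illustrative sketch of why the parallel and recurrent forms agree, for a single head with a scalar decay and without the normalization, scaling, or xPos rotation used by the library:

import torch

torch.manual_seed(0)
seq_len, d = 5, 4
decay = 0.9
q, k, v = torch.randn(seq_len, d), torch.randn(seq_len, d), torch.randn(seq_len, d)

# Parallel form: D[n, m] = decay^(n-m) for m <= n, else 0.
idx = torch.arange(seq_len)
D = torch.where(
    idx[:, None] >= idx[None, :],
    decay ** (idx[:, None] - idx[None, :]).float(),
    torch.zeros(1),
)
parallel_out = ((q @ k.T) * D) @ v

# Recurrent form: S_n = decay * S_{n-1} + k_n^T v_n, output_n = q_n S_n.
state = torch.zeros(d, d)
recurrent_out = []
for n in range(seq_len):
    state = decay * state + k[n].unsqueeze(1) @ v[n].unsqueeze(0)
    recurrent_out.append(q[n] @ state)
recurrent_out = torch.stack(recurrent_out)

assert torch.allclose(parallel_out, recurrent_out, atol=1e-5)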
@@ -1,25 +0,0 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output
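A quick sanity check of the module above, assuming RMSNorm as defined here is in scope: each feature vector is rescaled to roughly unit root-mean-square, with no mean subtraction as in LayerNorm.

import torch

norm = RMSNorm(dim=8, elementwise_affine=False)
x = torch.randn(2, 4, 8) * 3.0 + 1.0
y = norm(x)
rms = y.pow(2).mean(-1).sqrt()
print(rms)  # all entries close to 1.0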
@@ -1,65 +0,0 @@
import torch.distributed as dist


def _find_my_group_index(grouped_ranks):
    my_rank = dist.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError


def get_moe_group(moe_expert_count=None):
    if dist.is_initialized():
        if not hasattr(get_moe_group, "_moe_groups"):
            world_size = dist.get_world_size()

            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                moe_groups = [[i] for i in range(world_size)]

            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                moe_groups = [
                    [i + j * moe_expert_count for j in range(ranks_per_group)]
                    for i in range(moe_expert_count)
                ]

            get_moe_group._moe_expert_count = moe_expert_count
            get_moe_group._moe_group_idx = moe_groups
            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
        return my_group_idx, get_moe_group._moe_groups[my_group_idx]


def get_all2all_group(moe_expert_count):
    if dist.is_initialized():
        if not hasattr(get_all2all_group, "_all2all_groups"):
            world_size = dist.get_world_size()

            # more experts than world size
            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                all2all_groups = [[i for i in range(world_size)]]

            # larger world than num experts
            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                all2all_groups = [
                    [i * moe_expert_count + j for j in range(moe_expert_count)]
                    for i in range(ranks_per_group)
                ]

            get_all2all_group._all2all_group_idx = all2all_groups
            get_all2all_group._all2all_groups = [
                dist.new_group(g) for g in all2all_groups
            ]

        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
        return get_all2all_group._all2all_groups[my_group_idx]
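An illustrative sketch of the rank layouts the two helpers above would compute, reproduced without torch.distributed for a hypothetical world_size of 8 and moe_expert_count of 4:

world_size, moe_expert_count = 8, 4
ranks_per_group = world_size // moe_expert_count

# One process group per expert; ranks {i, i + 4} hold replicas of expert i.
moe_groups = [
    [i + j * moe_expert_count for j in range(ranks_per_group)]
    for i in range(moe_expert_count)
]
print(moe_groups)      # [[0, 4], [1, 5], [2, 6], [3, 7]]

# All-to-all groups span one copy of every expert.
all2all_groups = [
    [i * moe_expert_count + j for j in range(moe_expert_count)]
    for i in range(ranks_per_group)
]
print(all2all_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]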
@@ -18,8 +18,6 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
-from .global_groups import get_all2all_group, get_moe_group
-
 try:
     from fairseq.modules.moe import MOELayer
@@ -63,6 +61,64 @@ class _AllToAll(torch.autograd.Function):
         return (None, _AllToAll.apply(ctx.group, *grad_output))
 
 
+def _find_my_group_index(grouped_ranks):
+    my_rank = dist.get_rank()
+    for i, group in enumerate(grouped_ranks):
+        if my_rank in group:
+            return i
+    raise RuntimeError
+
+
+def get_moe_group(moe_expert_count):
+    if dist.is_initialized():
+        if not hasattr(get_moe_group, "_moe_groups"):
+            world_size = dist.get_world_size()
+
+            if world_size <= moe_expert_count:
+                assert moe_expert_count % world_size == 0
+                moe_groups = [[i] for i in range(world_size)]
+
+            else:
+                assert world_size % moe_expert_count == 0
+                ranks_per_group = world_size // moe_expert_count
+                moe_groups = [
+                    [i + j * moe_expert_count for j in range(ranks_per_group)]
+                    for i in range(moe_expert_count)
+                ]
+
+            get_moe_group._moe_group_idx = moe_groups
+            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
+
+        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
+        return get_moe_group._moe_groups[my_group_idx]
+
+
+def get_all2all_group(moe_expert_count):
+    if dist.is_initialized():
+        if not hasattr(get_all2all_group, "_all2all_groups"):
+            world_size = dist.get_world_size()
+
+            # more experts than world size
+            if world_size <= moe_expert_count:
+                assert moe_expert_count % world_size == 0
+                all2all_groups = [[i for i in range(world_size)]]
+
+            # larger world than num experts
+            else:
+                assert world_size % moe_expert_count == 0
+                ranks_per_group = world_size // moe_expert_count
+                all2all_groups = [
+                    [i * moe_expert_count + j for j in range(moe_expert_count)]
+                    for i in range(ranks_per_group)
+                ]
+
+            get_all2all_group._all2all_group_idx = all2all_groups
+            get_all2all_group._all2all_groups = [
+                dist.new_group(g) for g in all2all_groups
+            ]
+
+        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
+        return get_all2all_group._all2all_groups[my_group_idx]
+
+
 class MOELayer(Base):
@@ -93,7 +149,7 @@ class MOELayer(Base):
             self.experts = cast(ModuleList, experts)
         else:
             self.experts = ModuleList([experts])
-        _, self.expert_group = get_moe_group(args.moe_expert_count)
+        self.expert_group = get_moe_group(args.moe_expert_count)
         self.all2all_group = get_all2all_group(args.moe_expert_count)
         self.world_size = dist.get_world_size(group=self.expert_group)
         self.all2all_size = dist.get_world_size(group=self.all2all_group)