454 lines
22 KiB
Python
454 lines
22 KiB
Python
# Copyright (c) 2022 Microsoft
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from fairseq import distributed_utils, utils
|
|
from fairseq.distributed import fsdp_wrap
|
|
from fairseq.models import (
|
|
FairseqEncoder,
|
|
FairseqEncoderDecoderModel,
|
|
register_model,
|
|
register_model_architecture,
|
|
)
|
|
from fairseq.models.transformer import Embedding
|
|
from fairseq.modules import PositionalEmbedding
|
|
from torch import Tensor
|
|
|
|
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
|
|
from torchscale.architecture.encoder import Encoder
|
|
|
|
from .language_modeling import LMDecoder as MTDecoder
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fallback position limits, applied in TranslationModel.build_model when the
# corresponding ``max_*_positions`` attribute on ``args`` is missing or None.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

# Default for --min-params-to-wrap (FSDP wrapping threshold): 100 million.
DEFAULT_MIN_PARAMS_TO_WRAP = 10 ** 8
|
|
|
|
|
|
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
    """Encoder-decoder translation model registered with fairseq as ``mt``.

    The encoder is an ``MTEncoder`` and the decoder is ``MTDecoder`` (the
    ``LMDecoder`` imported from ``.language_modeling``). Both are configured
    through an ``EncoderDecoderConfig`` whose fields are overridden from the
    parsed command-line ``args``.
    """

    def __init__(self, args, encoder, decoder):
        """Store ``args`` and hand the encoder/decoder pair to the base class."""
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        # NOTE: the original statement ended with a stray trailing comma, which
        # turned it into a discarded one-element tuple; it is removed here.
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE Layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument(
            '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
            help=(
                "whether to normalize gate probs before or after dropping experts "
                "for capacity and randomization"
            )
        )
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE Expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top1 gate instead of top2")
        # The two literals below used to be joined without any separator,
        # which rendered as "...during validationif set to negative...".
        parser.add_argument(
            '--moe-eval-capacity-token-fraction', type=float, default=0.25,
            help=(
                "Fraction of tokens as capacity during validation; "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        )
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-Moe")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE Params")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed options; missing attributes are filled in by
                ``base_architecture`` and several are updated in place
                (``max_*_positions``, ``ddp_rank``,
                ``share_decoder_input_output_embed``, ``checkpoint_activations``).
            task: object exposing ``source_dictionary`` and ``target_dictionary``.

        Returns:
            A ``cls`` instance wrapping the freshly built encoder and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is set but the
                dictionaries differ, the encoder/decoder embedding dimensions
                differ, or a distinct ``--decoder-embed-path`` is given.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        # Validate the sharing constraints before any embedding is allocated.
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )

        # The encoder embedding is built the same way whether or not it is shared.
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        if args.share_all_embeddings:
            decoder_embed_tokens = encoder_embed_tokens
            # Sharing everything also ties the decoder output projection.
            args.share_decoder_input_output_embed = True
        else:
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        use_positions = not args.no_token_positional_embeddings
        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if use_positions
            else None
        )
        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if use_positions
            else None
        )

        if args.share_decoder_input_output_embed:
            # Tie the output projection to the decoder input embedding matrix.
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create a token embedding for ``dictionary``.

        Args:
            args: unused here; kept for interface compatibility.
            dictionary: vocabulary providing ``len()`` and ``pad()``.
            embed_dim: embedding dimension.
            path: optional path to pre-trained embeddings to load.

        Returns:
            The ``Embedding`` module, with pre-trained vectors loaded when
            ``path`` is truthy.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        """Build the ``MTEncoder`` from an ``EncoderDecoderConfig`` overridden by ``args``."""
        config = EncoderDecoderConfig()
        config.override(args)

        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(
        cls, args, embed_tokens, embed_positions, output_projection, dictionary
    ):
        """Build the ``MTDecoder`` from an ``EncoderDecoderConfig`` overridden by ``args``."""
        config = EncoderDecoderConfig()
        config.override(args)

        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        """Run the encoder on ``src_tokens`` and feed its output to the decoder.

        ``src_lengths`` and any extra ``kwargs`` are accepted but are not
        forwarded to either the encoder or the decoder.

        Returns:
            Whatever the decoder returns for ``prev_output_tokens``.
        """
        encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
|
|
|
|
|
class MTEncoder(Encoder, FairseqEncoder):
    """Encoder used by the ``mt`` model: ``Encoder`` plus fairseq encoder hooks."""

    def forward(self, src_tokens, **kwargs):
        """Encode ``src_tokens``, masking positions equal to the pad index.

        The padding mask is derived from ``self.dictionary.pad()`` and passed
        to the parent ``forward`` together with any extra ``kwargs``.
        """
        pad_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(
            src_tokens=src_tokens, encoder_padding_mask=pad_mask, **kwargs
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder every tensor in ``encoder_out`` along dim 0 by ``new_order``.

        The ``encoder_states`` list of the incoming dict is updated in place
        and the same list object is placed in the returned dict.
        """
        # NOTE(review): every tensor is index-selected on dim 0, so dim 0 is
        # presumably the batch axis for all of them -- confirm against the
        # parent Encoder's output layout.
        reordered = {
            key: encoder_out[key].index_select(0, new_order)
            for key in ("encoder_out", "encoder_embedding", "encoder_padding_mask")
        }

        states = encoder_out["encoder_states"]
        if states:
            # Slice assignment keeps the caller-visible list object identical.
            states[:] = [hidden.index_select(0, new_order) for hidden in states]

        return {
            "encoder_out": reordered["encoder_out"],
            "encoder_padding_mask": reordered["encoder_padding_mask"],
            "encoder_embedding": reordered["encoder_embedding"],
            "encoder_states": states,
        }

    def max_positions(self):
        """Return ``max_positions`` of the positional-embedding module."""
        # NOTE(review): build_model passes embed_positions=None when
        # --no-token-positional-embeddings is set; verify this attribute is
        # never None when this method is called.
        return self.embed_positions.max_positions
|
|
|
|
|
|
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    """Fill in the ``mt_base`` defaults for every option missing from ``args``.

    Attributes already present on ``args`` are kept unchanged (even if they
    are ``None``); only absent ones receive the default. ``args`` is modified
    in place and nothing is returned.
    """

    def _default(name, fallback):
        # Re-assign the existing value when present, else install the fallback.
        setattr(args, name, getattr(args, name, fallback))

    for name, fallback in (
        ("encoder_embed_path", None),
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_normalize_before", False),
        ("encoder_learned_pos", False),
        ("decoder_embed_path", None),
    ):
        _default(name, fallback)

    # Decoder widths fall back to the (now resolved) encoder widths.
    _default("decoder_embed_dim", args.encoder_embed_dim)
    _default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)

    for name, fallback in (
        ("decoder_layers", 6),
        ("decoder_attention_heads", 8),
        ("decoder_normalize_before", False),
        ("decoder_learned_pos", False),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("activation_fn", "relu"),
        ("dropout", 0.1),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
        ("share_decoder_input_output_embed", False),
        ("share_all_embeddings", False),
        ("no_token_positional_embeddings", False),
        ("adaptive_input", False),
        ("no_cross_attention", False),
        ("cross_self_attention", False),
    ):
        _default(name, fallback)

    # Decoder input/output dims fall back to the resolved decoder embed dim.
    _default("decoder_output_dim", args.decoder_embed_dim)
    _default("decoder_input_dim", args.decoder_embed_dim)

    for name, fallback in (
        ("no_scale_embedding", False),
        ("layernorm_embedding", False),
        ("tie_adaptive_weights", False),
        ("checkpoint_activations", False),
        ("offload_activations", False),
    ):
        _default(name, fallback)

    # Offloading activations forces activation checkpointing on.
    if args.offload_activations:
        args.checkpoint_activations = True

    for name, fallback in (
        ("encoder_layers_to_keep", None),
        ("decoder_layers_to_keep", None),
        ("encoder_layerdrop", 0),
        ("decoder_layerdrop", 0),
        ("quant_noise_pq", 0),
        ("quant_noise_pq_block_size", 8),
        ("quant_noise_scalar", 0),
        ("is_moe", False),
        ("selected_expert_count", 2),
    ):
        _default(name, fallback)
|