# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List, Optional, Tuple

import torch
from fairseq import distributed_utils, utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
from torch import Tensor

from torchscale.architecture.config import (
    DecoderConfig,
    EncoderConfig,
    EncoderDecoderConfig,
)
from torchscale.architecture.encoder import Encoder

from .language_modeling import LMDecoder as MTDecoder

logger = logging.getLogger(__name__)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim)')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help="if True, don't scale embeddings")
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then offload them to CPU. '
                                 'Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D',
            default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top-2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument(
            '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
            help=(
                "whether to normalize gate probs before or after dropping experts "
                "for capacity and randomization"
            )
        )
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top-1 gate instead of top-2")
        parser.add_argument(
            '--moe-eval-capacity-token-fraction', type=float, default=0.25,
            help=(
                "Fraction of tokens used as capacity during validation; "
                "if set to a negative value, use the same capacity as training. Range: (0.0, 1.0]."
            )
        )
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' or (2) 'sqrt_world_size'")
        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-MoE")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE parameters")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='',
                            help="path to a pretrained dense MT model")
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='number of buckets for relative position embeddings')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='maximum distance for relative position embeddings')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )
        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        config = EncoderDecoderConfig()
        config.override(args)

        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(
        cls, args, embed_tokens, embed_positions, output_projection, dictionary
    ):
        config = EncoderDecoderConfig()
        config.override(args)

        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
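

# MTEncoder couples the torchscale ``Encoder`` with fairseq's ``FairseqEncoder``
# interface: ``forward`` derives the self-attention padding mask from the source
# tokens, and ``reorder_encoder_out`` reorders the cached encoder outputs so they
# can follow beam reordering in fairseq's sequence generator.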
class MTEncoder(Encoder, FairseqEncoder):
    def forward(self, src_tokens, **kwargs):
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(
            src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
        new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
            0, new_order
        )
        new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
            0, new_order
        )

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(0, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
        }

    def max_positions(self):
        return self.embed_positions.max_positions
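

# The `mt_base` architecture below back-fills defaults for any options that were
# not set on the command line (and keeps older checkpoints loadable); the values
# mirror the fairseq transformer base configuration (512-dim embeddings, 2048-dim
# FFNs, 6 layers and 8 heads per stack).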
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
    args.is_moe = getattr(args, "is_moe", False)
    args.selected_expert_count = getattr(args, "selected_expert_count", 2)
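

# A minimal usage sketch, not part of the module: assuming this package is made
# visible to fairseq via --user-dir and <data-bin> points to a binarized parallel
# corpus with a joined dictionary, the registered `mt_base` architecture can be
# trained with standard fairseq flags, e.g.:
#
#   fairseq-train <data-bin> \
#       --user-dir <path-to-this-package> \
#       --task translation --arch mt_base --share-all-embeddings \
#       --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
#       --max-tokens 4096 \
#       --criterion label_smoothed_cross_entropy --label-smoothing 0.1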