# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap
from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import (
    AdaptiveSoftmax,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
from torch import Tensor
import logging

logger = logging.getLogger(__name__)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
    """Encoder-decoder translation model registered with fairseq as ``mt``.

    The encoder is an :class:`MTEncoder` (a torchscale ``Encoder``) and the
    decoder is ``MTDecoder`` (the ``LMDecoder`` imported from
    ``.language_modeling``). Instances are normally created through
    :meth:`build_model`.
    """

    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE Layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
                            help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE Expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top1 gate instead of top2")
        # The two help fragments were previously joined without a separator
        # ("validationif set ..."); a space is inserted between them.
        parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
                            help="Fraction of tokens as capacity during validation "
                                 "if set to negative, use same as training. range: (0.0, 1.0].")
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-Moe")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE Params")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed argument namespace. It is mutated in place: defaults
                from :func:`base_architecture` are filled in, and
                ``max_source_positions``, ``max_target_positions``,
                ``ddp_rank``, ``checkpoint_activations`` and
                ``share_decoder_input_output_embed`` may be set.
            task: object exposing ``source_dictionary`` and
                ``target_dictionary``.

        Returns:
            A ``cls`` instance wrapping the constructed encoder and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is requested with
                different source/target dictionaries, mismatched embedding
                dimensions, or a distinct ``--decoder-embed-path``.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # A single embedding table serves encoder input, decoder input
            # and (through the tied projection below) decoder output.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            # Tie the output projection to the decoder input embedding.
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            # Previously referenced an undefined local ``decoder_embed_dim``
            # (NameError whenever the embeddings were not shared); the value
            # lives on ``args``.
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create a token embedding sized for ``dictionary``.

        Args:
            args: unused here; kept for interface compatibility.
            dictionary: vocabulary providing ``len()`` and ``pad()``.
            embed_dim: embedding dimension.
            path: optional path to pre-trained embeddings; when truthy the
                embeddings are parsed and loaded into the new table.

        Returns:
            The constructed embedding module.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        """Build an :class:`MTEncoder` from an ``EncoderConfig`` overridden by ``args``."""
        config = EncoderConfig()
        config.override(args)

        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(
        cls, args, embed_tokens, embed_positions, output_projection, dictionary
    ):
        """Build an ``MTDecoder`` from a ``DecoderConfig`` overridden by ``args``."""
        config = DecoderConfig()
        config.override(args)

        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        """Run the encoder on ``src_tokens`` and feed its output to the decoder.

        ``src_lengths`` and any extra ``kwargs`` are accepted for interface
        compatibility but are not forwarded to the encoder or decoder.

        Returns:
            Whatever the decoder returns.
        """
        encoder_out = self.encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)


class MTEncoder(Encoder, FairseqEncoder):
    """torchscale ``Encoder`` exposed through the fairseq encoder interface."""

    def forward(self, src_tokens, **kwargs):
        """Encode ``src_tokens``, masking positions equal to the pad index."""
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(
            src_tokens=src_tokens,
            encoder_padding_mask=self_attn_padding_mask,
            **kwargs
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output along the batch axis according to ``new_order``.

        Args:
            encoder_out: dict with keys ``encoder_out``, ``encoder_embedding``,
                ``encoder_padding_mask`` and ``encoder_states``.
            new_order: index tensor passed to ``index_select``.

        Returns:
            A new dict with the same keys holding the reordered tensors. Note
            that the ``encoder_states`` list of the input is updated in place
            and reused in the result.
        """
        new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
        new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
            0, new_order
        )
        new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
            0, new_order
        )

        encoder_states = encoder_out["encoder_states"]
        for idx, state in enumerate(encoder_states):
            encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
        }

    def max_positions(self):
        """Return ``max_positions`` of the positional embedding module."""
        # NOTE(review): ``embed_positions`` is None when
        # --no-token-positional-embeddings is set, in which case this raises
        # AttributeError -- confirm whether a fallback is needed.
        return self.embed_positions.max_positions


@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    """Fill in default hyper-parameters on ``args`` for the ``mt_base`` architecture.

    Every attribute already present on ``args`` is kept; missing ones receive
    the defaults below. ``args`` is mutated in place and nothing is returned.
    """
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    # Decoder sizes default to the encoder's.
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        # Offloading implies checkpointing.
        args.checkpoint_activations = True
    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
    args.is_moe = getattr(args, "is_moe", False)
    args.selected_expert_count = getattr(args, "selected_expert_count", 2)