# torchscale/examples/fairseq/models/machine_translation.py
# 2022-11-23 08:36:55 -08:00
#
# 454 lines
# 22 KiB
# Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap
from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
from torch import Tensor
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Position limits applied by TranslationModel.build_model when the args
# namespace does not provide max_source_positions / max_target_positions.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

# Default for --min-params-to-wrap (FSDP wrapping threshold): 100 million.
DEFAULT_MIN_PARAMS_TO_WRAP = 10 ** 8
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
    """Encoder-decoder translation model registered with fairseq as ``mt``.

    The encoder is an ``MTEncoder`` and the decoder is ``MTDecoder`` (the
    ``LMDecoder`` class imported under that alias by this module). Both are
    configured from the fairseq ``args`` namespace through torchscale
    ``EncoderConfig`` / ``DecoderConfig`` objects.
    """

    def __init__(self, args, encoder, decoder):
        """Hand the encoder/decoder to the base class and keep ``args``."""
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        # A stray trailing comma after this call (which turned the statement
        # into a discarded 1-tuple) has been removed.
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE Layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
                            help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE Expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top1 gate instead of top2")
        # The two help fragments used to be joined without a separator
        # ("...validationif set..."); they are now two proper sentences.
        parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
                            help="Fraction of tokens as capacity during validation. "
                                 "If set to negative, use same as training. range: (0.0, 1.0].")
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-Moe")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE Params")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Fills in missing options on ``args``, creates token and positional
        embeddings plus the decoder output projection, builds the encoder and
        decoder, and (unless all embeddings are shared) passes both through
        ``fsdp_wrap``.

        Args:
            args: fairseq argument namespace; it is modified in place
                (architecture defaults, ``max_*_positions``, ``ddp_rank``,
                and possibly ``share_decoder_input_output_embed`` /
                ``checkpoint_activations``).
            task: object exposing ``source_dictionary`` and
                ``target_dictionary``.

        Returns:
            An instance of ``cls`` wrapping the constructed encoder/decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is set while the two
                dictionaries differ, the encoder/decoder embedding sizes
                differ, or a distinct ``--decoder-embed-path`` is given.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        # Record this process's data-parallel rank on the args namespace.
        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # One embedding table serves encoder input, decoder input and
            # (through the flag below) the decoder output projection.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            # Tie the projection to the decoder embedding matrix.
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            # The embedding size must be read from args: a bare
            # ``decoder_embed_dim`` is not defined in this scope and raised
            # NameError whenever the embeddings were not shared.
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create a token embedding sized for ``dictionary``.

        Args:
            args: unused here; kept for the builder signature.
            dictionary: supplies the vocabulary size (``len``) and the
                padding index (``pad()``).
            embed_dim: embedding dimension.
            path: optional path to pre-trained embeddings; when truthy, it is
                parsed and loaded into the new embedding.

        Returns:
            The embedding module.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        """Build an ``MTEncoder`` from an ``EncoderConfig`` overridden by ``args``."""
        config = EncoderConfig()
        config.override(args)
        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
        """Build an ``MTDecoder`` from a ``DecoderConfig`` overridden by ``args``."""
        config = DecoderConfig()
        config.override(args)
        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        """Run the encoder on ``src_tokens`` and feed its output to the decoder.

        ``src_lengths`` and any extra keyword arguments are accepted but are
        not forwarded to either module.

        Returns:
            Whatever the decoder returns.
        """
        encoder_out = self.encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class MTEncoder(Encoder, FairseqEncoder):
    """torchscale ``Encoder`` combined with the fairseq encoder interface."""

    def forward(self, src_tokens, **kwargs):
        """Encode ``src_tokens``, masking positions equal to the pad index.

        The padding mask is derived from ``self.dictionary.pad()`` and passed
        to the parent ``forward`` as ``encoder_padding_mask``; any extra
        keyword arguments are forwarded unchanged.
        """
        pad_positions = src_tokens.eq(self.dictionary.pad())
        return super().forward(
            src_tokens=src_tokens,
            encoder_padding_mask=pad_positions,
            **kwargs,
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Re-index every tensor in ``encoder_out`` by ``new_order``.

        ``encoder_out`` and each entry of ``encoder_states`` are selected
        along dim 1; ``encoder_embedding`` and ``encoder_padding_mask`` along
        dim 0. The ``encoder_states`` list is updated in place and the same
        list object is returned inside the new dict.
        """

        def select(key, dim):
            # Gather the entries of one output tensor in the requested order.
            return encoder_out[key].index_select(dim, new_order)

        out_reordered = select("encoder_out", 1)
        embedding_reordered = select("encoder_embedding", 0)
        mask_reordered = select("encoder_padding_mask", 0)

        layer_states = encoder_out["encoder_states"]
        if len(layer_states) > 0:
            # Slice assignment keeps the caller's list object, as before.
            layer_states[:] = [
                layer_state.index_select(1, new_order)
                for layer_state in layer_states
            ]

        return {
            "encoder_out": out_reordered,  # T x B x C
            "encoder_padding_mask": mask_reordered,
            "encoder_embedding": embedding_reordered,  # B x T x C
            "encoder_states": layer_states,  # List[T x B x C]
        }

    def max_positions(self):
        """Return ``max_positions`` of the positional embedding module."""
        return self.embed_positions.max_positions
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    """Install the ``mt_base`` defaults on ``args`` for every missing option.

    An attribute already present on ``args`` keeps its value; an absent one
    receives the fallback listed below. Some fallbacks reuse values resolved
    earlier (decoder sizes default to the encoder sizes), so the order of the
    calls matters. ``offload_activations`` forces ``checkpoint_activations``
    on.
    """

    def keep_or_default(option, fallback):
        # Re-assign the existing value, or the fallback when it is missing.
        setattr(args, option, getattr(args, option, fallback))

    # Encoder
    keep_or_default("encoder_embed_path", None)
    keep_or_default("encoder_embed_dim", 512)
    keep_or_default("encoder_ffn_embed_dim", 2048)
    keep_or_default("encoder_layers", 6)
    keep_or_default("encoder_attention_heads", 8)
    keep_or_default("encoder_normalize_before", False)
    keep_or_default("encoder_learned_pos", False)

    # Decoder (sizes fall back to the encoder's)
    keep_or_default("decoder_embed_path", None)
    keep_or_default("decoder_embed_dim", args.encoder_embed_dim)
    keep_or_default("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    keep_or_default("decoder_layers", 6)
    keep_or_default("decoder_attention_heads", 8)
    keep_or_default("decoder_normalize_before", False)
    keep_or_default("decoder_learned_pos", False)

    # Regularisation and activation
    keep_or_default("attention_dropout", 0.0)
    keep_or_default("activation_dropout", 0.0)
    keep_or_default("activation_fn", "relu")
    keep_or_default("dropout", 0.1)
    keep_or_default("adaptive_softmax_cutoff", None)
    keep_or_default("adaptive_softmax_dropout", 0)

    # Embedding sharing and attention variants
    keep_or_default("share_decoder_input_output_embed", False)
    keep_or_default("share_all_embeddings", False)
    keep_or_default("no_token_positional_embeddings", False)
    keep_or_default("adaptive_input", False)
    keep_or_default("no_cross_attention", False)
    keep_or_default("cross_self_attention", False)

    keep_or_default("decoder_output_dim", args.decoder_embed_dim)
    keep_or_default("decoder_input_dim", args.decoder_embed_dim)

    keep_or_default("no_scale_embedding", False)
    keep_or_default("layernorm_embedding", False)
    keep_or_default("tie_adaptive_weights", False)

    # Activation checkpointing; offloading implies checkpointing.
    keep_or_default("checkpoint_activations", False)
    keep_or_default("offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True

    # Layer pruning / LayerDrop
    keep_or_default("encoder_layers_to_keep", None)
    keep_or_default("decoder_layers_to_keep", None)
    keep_or_default("encoder_layerdrop", 0)
    keep_or_default("decoder_layerdrop", 0)

    # Quantization noise
    keep_or_default("quant_noise_pq", 0)
    keep_or_default("quant_noise_pq_block_size", 8)
    keep_or_default("quant_noise_scalar", 0)

    # Mixture-of-experts
    keep_or_default("is_moe", False)
    keep_or_default("selected_expert_count", 2)