torchscale/examples/fairseq/models/machine_translation.py

454 lines
22 KiB
Python
Raw Permalink Normal View History

2022-11-23 16:36:55 +00:00
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
2022-11-23 16:21:58 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
2022-11-26 17:01:02 +00:00
import logging
2022-11-26 16:10:15 +00:00
from typing import Dict, List, Optional, Tuple
2022-11-23 16:21:58 +00:00
import torch
2022-11-26 17:01:02 +00:00
from fairseq import distributed_utils, utils
2022-12-08 12:20:27 +00:00
from fairseq.distributed import fsdp_wrap
2022-11-23 16:21:58 +00:00
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
2022-11-26 16:10:15 +00:00
from fairseq.modules import PositionalEmbedding
2022-11-26 17:01:02 +00:00
from torch import Tensor
2023-07-31 16:17:03 +00:00
from torchscale.architecture.config import DecoderConfig, EncoderConfig, EncoderDecoderConfig
2022-11-23 16:21:58 +00:00
from torchscale.architecture.encoder import Encoder
2022-11-26 17:01:02 +00:00
2022-11-23 16:21:58 +00:00
from .language_modeling import LMDecoder as MTDecoder
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fallback position limits applied by TranslationModel.build_model when the
# parsed arguments do not provide max_source_positions / max_target_positions.
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

# Default for --min-params-to-wrap (the FSDP wrapping threshold): 1e8 params.
DEFAULT_MIN_PARAMS_TO_WRAP = 100_000_000
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
    """Encoder-decoder translation model registered with fairseq as ``mt``.

    The encoder is an :class:`MTEncoder` and the decoder is ``MTDecoder``
    (``LMDecoder`` from the sibling ``language_modeling`` module); both are
    configured from an ``EncoderDecoderConfig`` overridden with the parsed
    command-line arguments.
    """

    def __init__(self, args, encoder, decoder):
        """Store the parsed arguments and register the encoder/decoder pair.

        Args:
            args: parsed command-line arguments (kept on ``self.args``).
            encoder: the encoder module.
            decoder: the decoder module.
        """
        super().__init__(encoder, decoder)
        self.args = args

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
                            help='dropout probability after activation in FFN.')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--decoder-output-dim', type=int, metavar='N',
                            help='decoder output dimension (extra linear layer '
                                 'if different from decoder embed dim)')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')
        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                            help='if set, disables positional embeddings (outside self attention)')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                            help='sets adaptive softmax dropout for the tail projections')
        parser.add_argument('--layernorm-embedding', action='store_true',
                            help='add layernorm to embedding')
        parser.add_argument('--no-scale-embedding', action='store_true',
                            help='if True, dont scale embeddings')
        parser.add_argument('--checkpoint-activations', action='store_true',
                            help='checkpoint activations at each layer, which saves GPU '
                                 'memory usage at the cost of some additional compute')
        parser.add_argument('--offload-activations', action='store_true',
                            help='checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations.')
        # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
        parser.add_argument('--no-cross-attention', default=False, action='store_true',
                            help='do not perform cross-attention')
        parser.add_argument('--cross-self-attention', default=False, action='store_true',
                            help='perform cross+self-attention')
        # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
        parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for encoder')
        parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
                            help='LayerDrop probability for decoder')
        parser.add_argument('--encoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        parser.add_argument('--decoder-layers-to-keep', default=None,
                            help='which layers to *keep* when pruning as a comma-separated list')
        # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
        parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0,
                            help='iterative PQ quantization noise at training time')
        parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8,
                            help='block size of quantization noise at training time')
        parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0,
                            help='scalar quantization noise and scalar quantization at training time')
        # args for Fully Sharded Data Parallel (FSDP) training
        parser.add_argument(
            '--min-params-to-wrap', type=int, metavar='D', default=DEFAULT_MIN_PARAMS_TO_WRAP,
            help=(
                'minimum number of params for a layer to be wrapped with FSDP() when '
                'training with --ddp-backend=fully_sharded. Smaller values will '
                'improve memory efficiency, but may make torch.distributed '
                'communication less efficient due to smaller input sizes. This option '
                'is set to 0 (i.e., always wrap) when --checkpoint-activations or '
                '--offload-activations are passed.'
            )
        )
        # args for mixture-of-expert layers
        parser.add_argument('--moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer layers')
        parser.add_argument('--encoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer encoder layers')
        parser.add_argument('--decoder-moe-freq', type=int, metavar='D', default=0,
                            help='Frequency at which we insert MoE Transformer decoder layers')
        parser.add_argument('--moe-expert-count', type=int, metavar='D', default=0,
                            help='Number of experts in each MoE Layer')
        parser.add_argument('--moe-gating-use-fp32', default=False, action='store_true',
                            help="Use FP32 computations in MoE top2 gating function")
        parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                            help="policy for second expert, options: all/sampling/random")
        parser.add_argument(
            '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
            help=(
                "whether to normalize gate probs before or after dropping experts "
                "for capacity and randomization"
            )
        )
        parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                            help="MoE Expert FFN dimension")
        parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                            help="Use top1 gate instead of top2")
        parser.add_argument(
            '--moe-eval-capacity-token-fraction', type=float, default=0.25,
            help=(
                "Fraction of tokens as capacity during validation; "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        )
        parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                            help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
        parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                            help="Don't route padding tokens to any expert")
        parser.add_argument('--use-xmoe', default=False, action='store_true',
                            help="Enable X-Moe")
        parser.add_argument('--freeze-moe', default=False, action='store_true',
                            help="Freeze MoE Params")
        parser.add_argument('--deepnorm', default=False, action='store_true',
                            help="Enable DeepNorm")
        parser.add_argument('--subln', default=False, action='store_true',
                            help="Enable SubLN")
        parser.add_argument('--pretrained-dense-mt-model-path', type=str, default='')
        # args for pseudo-MoE layers
        parser.add_argument('--alternate-ffn-embed-dim', type=int, default=0,
                            help="FFN embed dim of alternate pseudo-MoE blocks")
        parser.add_argument('--rel-pos-buckets', type=int, default=0,
                            help='')
        parser.add_argument('--max-rel-pos', type=int, default=0,
                            help='')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed command-line arguments. Mutated in place: missing
                architecture defaults are filled in, ``ddp_rank`` is set, and
                ``share_decoder_input_output_embed`` / ``checkpoint_activations``
                may be forced to ``True``.
            task: object exposing ``source_dictionary`` and
                ``target_dictionary``.

        Returns:
            A ``cls`` instance wrapping the freshly built encoder and decoder.

        Raises:
            ValueError: if ``--share-all-embeddings`` is set but the
                dictionaries differ, the encoder/decoder embedding dims
                differ, or a distinct ``--decoder-embed-path`` is given.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        # Validate the sharing constraints before allocating any embedding.
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )

        # The encoder embedding is built the same way in both modes.
        encoder_embed_tokens = cls.build_embedding(
            args, src_dict, args.encoder_embed_dim, args.encoder_embed_path
        )
        if args.share_all_embeddings:
            decoder_embed_tokens = encoder_embed_tokens
            # Sharing everything implies tying the decoder output projection.
            args.share_decoder_input_output_embed = True
        else:
            decoder_embed_tokens = cls.build_embedding(
                args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        encoder_embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                args.encoder_embed_dim,
                src_dict.pad(),
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        decoder_embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                tgt_dict.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            # Tie the output projection to the decoder input embedding matrix.
            output_projection = torch.nn.Linear(
                decoder_embed_tokens.weight.shape[1],
                decoder_embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = decoder_embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(tgt_dict), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        encoder = cls.build_encoder(
            args,
            encoder_embed_tokens,
            encoder_embed_positions,
            src_dict,
        )
        decoder = cls.build_decoder(
            args,
            decoder_embed_tokens,
            decoder_embed_positions,
            output_projection,
            tgt_dict,
        )

        if not args.share_all_embeddings:
            min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
            # fsdp_wrap is a no-op when --ddp-backend != fully_sharded
            encoder = fsdp_wrap(encoder, min_num_params=min_params_to_wrap)
            decoder = fsdp_wrap(decoder, min_num_params=min_params_to_wrap)
        return cls(args, encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create a token embedding sized for ``dictionary``.

        Args:
            args: parsed command-line arguments (not read here).
            dictionary: vocabulary; supplies the size and the padding index.
            embed_dim: embedding dimension.
            path: optional path to pre-trained embeddings to load.

        Returns:
            The embedding module.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, args, embed_tokens, embed_positions, dictionary):
        """Build the :class:`MTEncoder` from ``args``.

        Args:
            args: parsed command-line arguments used to override the default
                ``EncoderDecoderConfig``.
            embed_tokens: source token embedding.
            embed_positions: source positional embedding, or ``None``.
            dictionary: source dictionary.
        """
        config = EncoderDecoderConfig()
        config.override(args)
        return MTEncoder(
            config,
            embed_tokens,
            embed_positions,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    @classmethod
    def build_decoder(
        cls, args, embed_tokens, embed_positions, output_projection, dictionary
    ):
        """Build the ``MTDecoder`` from ``args``.

        Args:
            args: parsed command-line arguments used to override the default
                ``EncoderDecoderConfig``.
            embed_tokens: target token embedding.
            embed_positions: target positional embedding, or ``None``.
            output_projection: linear layer mapping decoder features to
                vocabulary logits.
            dictionary: target dictionary.
        """
        config = EncoderDecoderConfig()
        config.override(args)
        return MTDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=True,
            dictionary=dictionary,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = False,
        features_only: bool = False,
        **kwargs
    ):
        """Encode ``src_tokens`` and decode ``prev_output_tokens`` against it.

        ``src_lengths`` and ``**kwargs`` are accepted but not used by this
        implementation; the encoder is called with ``src_tokens`` only.

        Returns:
            Whatever the decoder returns.
        """
        encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
class MTEncoder(Encoder, FairseqEncoder):
    """torchscale ``Encoder`` exposed through the fairseq encoder interface."""

    def forward(self, src_tokens, **kwargs):
        """Run the encoder, masking positions equal to the dictionary's pad index.

        Args:
            src_tokens: source token ids.
            **kwargs: forwarded unchanged to ``Encoder.forward``.
        """
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(
            src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
        )

    def reorder_encoder_out(self, encoder_out, new_order):
        """Return a copy of ``encoder_out`` re-indexed along dim 0 by ``new_order``.

        Every tensor is selected along dimension 0, so the encoder outputs are
        presumably batch-first -- confirm against ``Encoder.forward``.

        Only the four keys below are carried over; any other key present in
        ``encoder_out`` is dropped. The input dict and its ``encoder_states``
        list are left unmodified.
        """
        new_encoder_out = encoder_out["encoder_out"].index_select(0, new_order)
        new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
            0, new_order
        )
        new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
            0, new_order
        )

        # Build a fresh list instead of overwriting the caller's list in place.
        new_encoder_states = [
            state.index_select(0, new_order)
            for state in encoder_out["encoder_states"]
        ]

        return {
            "encoder_out": new_encoder_out,  # reordered along dim 0
            "encoder_padding_mask": new_encoder_padding_mask,
            "encoder_embedding": new_encoder_embedding,  # reordered along dim 0
            "encoder_states": new_encoder_states,  # list, each reordered along dim 0
        }

    def max_positions(self):
        """Maximum input length supported by the encoder.

        Without positional embeddings (``embed_positions`` is ``None``, e.g.
        under ``--no-token-positional-embeddings``) there is no embedding
        table to query, so fall back to the configured source limit.
        """
        if self.embed_positions is None:
            # NOTE(review): assumes the base Encoder keeps its config on
            # ``self.args`` -- confirm; otherwise the module default is used.
            return getattr(
                getattr(self, "args", None),
                "max_source_positions",
                DEFAULT_MAX_SOURCE_POSITIONS,
            )
        return self.embed_positions.max_positions
2022-11-26 16:10:15 +00:00
2022-11-23 16:21:58 +00:00
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
    """Fill in ``mt_base`` defaults for every option missing from ``args``.

    Attributes already present on ``args`` keep their value; absent ones are
    created with the default below. Some decoder defaults follow the
    (possibly user-supplied) encoder or decoder settings, so the order of
    the assignments matters. ``args`` is modified in place.
    """

    def fill(name, fallback):
        # Keep an existing attribute; otherwise install the fallback.
        setattr(args, name, getattr(args, name, fallback))

    # Encoder.
    fill("encoder_embed_path", None)
    fill("encoder_embed_dim", 512)
    fill("encoder_ffn_embed_dim", 2048)
    fill("encoder_layers", 6)
    fill("encoder_attention_heads", 8)
    fill("encoder_normalize_before", False)
    fill("encoder_learned_pos", False)

    # Decoder: widths default to the encoder's.
    fill("decoder_embed_path", None)
    fill("decoder_embed_dim", args.encoder_embed_dim)
    fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    fill("decoder_layers", 6)
    fill("decoder_attention_heads", 8)
    fill("decoder_normalize_before", False)
    fill("decoder_learned_pos", False)

    # Regularisation and activation.
    fill("attention_dropout", 0.0)
    fill("activation_dropout", 0.0)
    fill("activation_fn", "relu")
    fill("dropout", 0.1)
    fill("adaptive_softmax_cutoff", None)
    fill("adaptive_softmax_dropout", 0)

    # Embedding sharing and attention variants.
    fill("share_decoder_input_output_embed", False)
    fill("share_all_embeddings", False)
    fill("no_token_positional_embeddings", False)
    fill("adaptive_input", False)
    fill("no_cross_attention", False)
    fill("cross_self_attention", False)

    # Decoder input/output widths default to the decoder embedding width.
    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    fill("no_scale_embedding", False)
    fill("layernorm_embedding", False)
    fill("tie_adaptive_weights", False)

    # Activation checkpointing: offloading forces checkpointing on.
    fill("checkpoint_activations", False)
    fill("offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True

    # Layer pruning / LayerDrop.
    fill("encoder_layers_to_keep", None)
    fill("decoder_layers_to_keep", None)
    fill("encoder_layerdrop", 0)
    fill("decoder_layerdrop", 0)

    # Quantization noise.
    fill("quant_noise_pq", 0)
    fill("quant_noise_pq_block_size", 8)
    fill("quant_noise_scalar", 0)

    # Mixture-of-experts.
    fill("is_moe", False)
    fill("selected_expert_count", 2)