torchscale/examples/fairseq/models/language_modeling.py

357 lines
13 KiB
Python
Raw Normal View History

2022-11-23 16:36:55 +00:00
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
2022-11-23 16:21:58 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
2022-11-26 16:10:15 +00:00
import logging
2022-11-23 16:21:58 +00:00
from dataclasses import dataclass, field
from typing import Optional
2022-11-26 17:01:02 +00:00
import torch
from fairseq import distributed_utils, utils
2022-11-23 16:21:58 +00:00
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
2022-11-26 17:01:02 +00:00
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
2022-11-23 16:21:58 +00:00
from fairseq.modules import PositionalEmbedding
from omegaconf import II
2022-11-26 17:01:02 +00:00
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
2022-11-23 16:21:58 +00:00
logger = logging.getLogger(__name__)

# Last-resort value for args.max_target_positions, used by
# LanguageModel.build_model when args has neither max_target_positions
# nor a tokens_per_sample attribute.
DEFAULT_MAX_TARGET_POSITIONS = 1024


@dataclass
class LanguageConfig(FairseqDataclass):
    """Options for the torchscale ``lm`` model.

    Every option is a dataclass ``field`` with a default and, for most, a
    ``help`` string in ``metadata``. The fields after the "options from other
    parts of the config" comment are ``II(...)`` interpolations of values
    owned by the ``task``, ``common`` and ``distributed_training`` groups.

    NOTE(review): fairseq presumably derives CLI arguments from these fields
    in declaration order — avoid reordering them.
    """

    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    # NOTE(review): relu_dropout has the same help text as activation_dropout;
    # presumably a legacy alias — confirm which one the decoder actually reads.
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_attention_heads: int = field(
        default=8, metadata={"help": "num decoder attention heads"}
    )
    # The lm_base architecture (base_lm_architecture) forces this to True
    # regardless of the default here.
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    # When True, LanguageModel.build_model passes embed_positions=None.
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    # When True, the output projection reuses the token-embedding weight.
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, dont scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    # base_lm_architecture turns checkpoint_activations on when this is set.
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    # NOTE(review): the help text says this is forced to 0 when activation
    # checkpointing/offloading is on; that override is not performed in this
    # file — confirm it happens elsewhere.
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
    # Mixture-of-Experts (MoE) options. LanguageModel.build_model treats
    # moe_freq > 0 as "MoE enabled" when validating fp16 settings.
    moe_freq: int = field(
        default=0,
        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
        default=0, metadata={"help": "Number of experts in each MoE Layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
    )
    # Declared as a plain str: the options listed in the help text are not
    # validated here.
    moe_second_expert_policy: str = field(
        default="sampling",
        metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
        },
    )
    moe_expert_ffn_dim: Optional[int] = field(
        default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
        default=False, metadata={"help": "Use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
        metadata={
            "help": (
                "Default: 0.25, Fraction of tokens as capacity during validation, "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        },
    )
    moe_normalize_expert_grad: Optional[str] = field(
        default="world_size",
        metadata={
            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
        },
    )
    record_a2a_perf_stats: Optional[bool] = field(
        default=False,
        metadata={"help": "records all to all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
        default=False,
        metadata={
            "help": "By passes all to all during distributed training by returning the input buffer as output"
        },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if true orders token by the gate prob before capacity dropping."
        },
    )
    # NOTE(review): no help text; presumably toggles torchscale's X-MoE
    # routing — confirm.
    use_xmoe: Optional[bool] = field(
        default=False,
    )
    # options from other parts of the config
    # (II("group.key") interpolates the value of that key from another group.)
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    # If this resolves to None, LanguageModel.build_model falls back to
    # tokens_per_sample.
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    fp16: bool = II("common.fp16")
    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
    ddp_backend: str = II("distributed_training.ddp_backend")
    world_size: int = II("distributed_training.distributed_world_size")
    distributed_rank: int = II("distributed_training.distributed_rank")
    # Same source key as distributed_rank; LanguageModel.build_model
    # overwrites args.ddp_rank with the data-parallel rank.
    ddp_rank: int = II("distributed_training.distributed_rank")
    # torchscale-specific decoder options (no help text provided).
    # NOTE(review): presumably copied onto DecoderConfig by
    # config.override(args) in LanguageModel.build_model — confirm the names
    # match DecoderConfig attributes.
    deepnorm: Optional[bool] = field(
        default=False,
    )
    subln: Optional[bool] = field(
        default=False,
    )
    rel_pos_buckets: Optional[int] = field(
        default=0,
    )
    max_rel_pos: Optional[int] = field(
        default=0,
    )
    xpos_rel_pos: Optional[bool] = field(
        default=False,
    )
    xpos_scale_base: Optional[int] = field(
        default=512,
    )


@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):
    """Decoder-only language model built on torchscale's ``Decoder``.

    Registered with fairseq under the model name ``lm``; its options are the
    fields of ``LanguageConfig``.
    """

    def __init__(self, args, decoder):
        # Initialise the nn.Module / fairseq base classes before attaching
        # any other attributes to the instance.
        super().__init__(decoder)
        self.args = args

    @classmethod
    def build_model(cls, args, task):
        """Build a ``LanguageModel`` from a configuration and a fairseq task.

        Args:
            args: model configuration (``LanguageConfig`` fields). It is
                mutated: ``max_target_positions`` may be filled in and
                ``ddp_rank`` is overwritten with the data-parallel rank.
            task: fairseq task supplying ``source_dictionary`` and
                ``dictionary``.

        Returns:
            LanguageModel: the assembled model.

        Raises:
            ValueError: MoE layers (``moe_freq > 0``) are combined with plain
                fp16 training (not memory-efficient, not fully sharded) but
                ``fp16_no_flatten_grads`` is not set.
        """
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        # NOTE(review): token embeddings are sized from task.source_dictionary
        # while positions and the output projection use task.dictionary;
        # presumably these are the same dictionary for LM tasks — confirm.
        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_embed_dim
        )

        embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                task.dictionary.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        # Built after the embeddings, as before, so parameter initialisation
        # consumes the RNG in the same order.
        output_projection = cls._build_output_projection(args, task, embed_tokens)

        cls._check_moe_fp16_settings(args)

        args.ddp_rank = distributed_utils.get_data_parallel_rank()
        config = DecoderConfig()
        config.override(args)

        decoder = LMDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=False,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @staticmethod
    def _build_output_projection(args, task, embed_tokens):
        """Return the bias-free projection from hidden states to vocabulary logits.

        The weight is tied to ``embed_tokens`` when
        ``share_decoder_input_output_embed`` is set; otherwise it is a fresh
        matrix initialised from N(0, decoder_embed_dim ** -0.5).
        """
        if args.share_decoder_input_output_embed:
            output_projection = torch.nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            # Share the Parameter object itself so input and output
            # embeddings stay identical during training.
            output_projection.weight = embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(task.dictionary), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )
        return output_projection

    @staticmethod
    def _check_moe_fp16_settings(args):
        """Reject MoE + plain fp16 training unless gradients are left unflattened.

        Raised explicitly rather than via ``assert`` so that the check is not
        stripped when Python runs with ``-O``.
        """
        if getattr(args, "moe_freq", 0) <= 0:
            return
        uses_plain_fp16 = (
            getattr(args, "fp16", False)
            and not getattr(args, "memory_efficient_fp16", False)
            and getattr(args, "ddp_backend", None) != "fully_sharded"
        )
        if uses_plain_fp16 and not getattr(args, "fp16_no_flatten_grads", False):
            raise ValueError(
                "If training moe models, set --fp16-no-flatten-grads "
                "to calculate correct gradnorm"
            )

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create the token-embedding table for ``dictionary``.

        Uses ``dictionary.pad()`` as the padding index. ``path`` is accepted
        but ignored: no pretrained embeddings are loaded here.
        """
        return Embedding(len(dictionary), embed_dim, dictionary.pad())


class LMDecoder(Decoder, FairseqIncrementalDecoder):
    """torchscale ``Decoder`` exposed through fairseq's incremental-decoder API."""

    def forward(self, src_tokens, **kwargs):
        """Decode ``src_tokens``, masking every position that holds the pad token.

        ``self.dictionary`` is presumably set by fairseq's decoder base class
        from the ``dictionary=`` keyword given at construction — confirm.
        Extra keyword arguments are passed through to ``Decoder.forward``.
        """
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

    def max_positions(self):
        """Maximum input length supported by the decoder.

        Returns the positional-embedding limit. When the model is built with
        ``no_token_positional_embeddings`` the decoder receives
        ``embed_positions=None``; there is then no table to bound the length,
        so fairseq's base-class default is returned instead of failing with
        ``AttributeError`` on ``None``.
        """
        if self.embed_positions is None:
            return FairseqIncrementalDecoder.max_positions(self)
        return self.embed_positions.max_positions

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        """Reorder cached decoder state in place to follow ``new_order``.

        Every cached tensor is re-indexed along dim 0 with ``new_order``
        (presumably called by fairseq's generator when beams are reordered).
        Only existing keys are reassigned, so mutating the dict while
        iterating it is safe.
        """
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result


@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
    """Apply the ``lm_base`` architecture defaults to ``args`` in place.

    Flags from older checkpoints are translated first. After that, any
    attribute missing from ``args`` receives the default listed below, and
    attributes already present are kept. The one exception is
    ``decoder_normalize_before``, which is always forced to True.
    """
    # Checkpoints that still carry --no-tie-adaptive-proj come from an older
    # version of the model; translate their flags before filling defaults.
    if hasattr(args, "no_tie_adaptive_proj"):
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    def fill(defaults):
        # Give each missing attribute its default, in the order listed.
        for attr_name, default in defaults:
            setattr(args, attr_name, getattr(args, attr_name, default))

    fill(
        (
            ("dropout", 0.1),
            ("attention_dropout", 0.0),
            ("decoder_embed_dim", 512),
            ("decoder_ffn_embed_dim", 2048),
            ("decoder_layers", 6),
            ("decoder_attention_heads", 8),
            ("adaptive_softmax_cutoff", None),
            ("adaptive_softmax_dropout", 0),
            ("adaptive_softmax_factor", 4),
            ("decoder_learned_pos", False),
            ("activation_fn", "relu"),
            ("decoder_layerdrop", 0),
            ("decoder_layers_to_keep", None),
            ("base_layers", 0),
            ("base_sublayers", 1),
            ("base_shuffle", False),
            ("add_bos_token", False),
            ("no_token_positional_embeddings", False),
            ("share_decoder_input_output_embed", False),
            ("character_embeddings", False),
        )
    )

    # The input/output dims default to decoder_embed_dim, which may itself
    # have just been defaulted above, so they are resolved only now.
    fill(
        (
            ("decoder_output_dim", args.decoder_embed_dim),
            ("decoder_input_dim", args.decoder_embed_dim),
        )
    )

    # Model training is not stable without this
    args.decoder_normalize_before = True

    fill(
        (
            ("no_decoder_final_norm", False),
            ("adaptive_input", False),
            ("adaptive_input_factor", 4),
            ("adaptive_input_cutoff", None),
            ("tie_adaptive_weights", False),
            ("tie_adaptive_proj", False),
            ("no_scale_embedding", False),
            ("layernorm_embedding", False),
            ("checkpoint_activations", False),
            ("offload_activations", False),
        )
    )

    # Offloading moves *checkpointed* activations, so it implies checkpointing.
    if args.offload_activations:
        args.checkpoint_activations = True