# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Decoder-only language model registered with fairseq as ``lm``.

The model wraps a torchscale ``Decoder`` so that it can be built and trained
through fairseq's model registry (``register_model`` /
``register_model_architecture``).
"""

import logging
from dataclasses import dataclass, field
from typing import Optional

import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqIncrementalDecoder,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from omegaconf import II

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)


@dataclass
class LanguageConfig(FairseqDataclass):
    """Command-line / config options of the ``lm`` model.

    Each field's ``help`` metadata documents the option. Fields assigned with
    ``II(...)`` are interpolated from other sections of the overall config
    rather than set on the model directly.
    """

    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_attention_heads: int = field(
        default=8, metadata={"help": "num decoder attention heads"}
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, dont scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
    moe_freq: int = field(
        default=0,
        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
        default=0, metadata={"help": "Number of experts in each MoE Layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
    )
    moe_second_expert_policy: str = field(
        default="sampling",
        metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
        },
    )
    moe_expert_ffn_dim: Optional[int] = field(
        default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
        default=False, metadata={"help": "Use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
        metadata={
            "help": (
                "Default: 0.25, Fraction of tokens as capacity during validation, "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
        },
    )
    moe_normalize_expert_grad: Optional[str] = field(
        default="world_size",
        metadata={
            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
        },
    )
    record_a2a_perf_stats: Optional[bool] = field(
        default=False,
        metadata={"help": "records all to all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
        default=False,
        metadata={
            "help": "By passes all to all during distributed training by returning the input buffer as output"
        },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
        default=False,
        metadata={
            "help": "if true orders token by the gate prob before capacity dropping."
        },
    )
    use_xmoe: Optional[bool] = field(
        default=False,
    )

    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
    memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
    fp16: bool = II("common.fp16")
    fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
    ddp_backend: str = II("distributed_training.ddp_backend")
    world_size: int = II("distributed_training.distributed_world_size")
    distributed_rank: int = II("distributed_training.distributed_rank")
    ddp_rank: int = II("distributed_training.distributed_rank")

    # The options below carry no help text in this file.
    # NOTE(review): they are presumably consumed by the torchscale
    # DecoderConfig through `config.override(args)` -- confirm their exact
    # semantics against that class.
    deepnorm: Optional[bool] = field(
        default=False,
    )
    subln: Optional[bool] = field(
        default=False,
    )
    rel_pos_buckets: Optional[int] = field(
        default=0,
    )
    max_rel_pos: Optional[int] = field(
        default=0,
    )
    xpos_rel_pos: Optional[bool] = field(
        default=False,
    )
    xpos_scale_base: Optional[int] = field(
        default=512,
    )


@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):
    """fairseq language model whose decoder is an ``LMDecoder``."""

    def __init__(self, args, decoder):
        """Store ``args`` on the model and hand ``decoder`` to the base class."""
        self.args = args
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance from ``args`` and ``task``.

        Creates the token embedding, the optional positional embedding and the
        output projection, copies ``args`` into a ``DecoderConfig`` and wraps
        the resulting ``LMDecoder`` in ``cls``.

        Raises:
            AssertionError: for MoE models (``moe_freq > 0``) trained with
                plain fp16 (not memory-efficient fp16, not the
                ``fully_sharded`` backend) when ``fp16_no_flatten_grads`` is
                not set.
        """
        # Fall back to the task's sample length when no explicit limit is set.
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        embed_tokens = cls.build_embedding(
            args, task.source_dictionary, args.decoder_embed_dim
        )

        embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                args.decoder_embed_dim,
                task.dictionary.pad(),
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        if args.share_decoder_input_output_embed:
            # Tie the projection to the input embedding: both share one weight.
            output_projection = torch.nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            output_projection.weight = embed_tokens.weight
        else:
            output_projection = torch.nn.Linear(
                args.decoder_embed_dim, len(task.dictionary), bias=False
            )
            torch.nn.init.normal_(
                output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
            )

        if getattr(args, "moe_freq", 0) > 0 and (
            getattr(args, "fp16", False)
            and not getattr(args, "memory_efficient_fp16", False)
            and getattr(args, "ddp_backend", None) != "fully_sharded"
        ):
            assert (
                args.fp16_no_flatten_grads
            ), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"

        args.ddp_rank = distributed_utils.get_data_parallel_rank()

        config = DecoderConfig()
        config.override(args)

        decoder = LMDecoder(
            config,
            embed_tokens,
            embed_positions,
            output_projection,
            is_encoder_decoder=False,
            dictionary=task.dictionary,
        )

        return cls(args, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Return a token embedding sized to ``dictionary``.

        ``args`` and ``path`` are accepted but not used here.
        """
        return Embedding(len(dictionary), embed_dim, dictionary.pad())


class LMDecoder(Decoder, FairseqIncrementalDecoder):
    """torchscale ``Decoder`` exposed through fairseq's incremental decoder API."""

    def forward(self, src_tokens, **kwargs):
        """Run the parent decoder with a padding mask derived from ``src_tokens``.

        The mask is true wherever ``src_tokens`` equals the dictionary's
        padding index; it is passed as the second positional argument.
        """
        self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
        return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

    def max_positions(self):
        """Return the maximum sequence length the decoder supports.

        When a positional-embedding module is present, its ``max_positions``
        value is returned. ``LanguageModel.build_model`` passes
        ``embed_positions=None`` if ``no_token_positional_embeddings`` is set;
        in that case the configured ``max_target_positions`` is used, and
        ``DEFAULT_MAX_TARGET_POSITIONS`` as a last resort.
        """
        if self.embed_positions is not None:
            return self.embed_positions.max_positions

        # No positional-embedding module: dereferencing None here used to
        # raise AttributeError.
        # NOTE(review): assumes the parent Decoder keeps its config on
        # `self.args`; if it does not, the module-level default applies.
        configured = getattr(
            getattr(self, "args", None), "max_target_positions", None
        )
        if configured is not None:
            return configured
        return DEFAULT_MAX_TARGET_POSITIONS

    def reorder_incremental_state_scripting(
        self,
        incremental_state,
        new_order,
    ):
        """Reorder every cached entry along dimension 0, in place.

        ``incremental_state`` is treated as a two-level mapping
        (module -> key -> value); each value is replaced by
        ``value.index_select(0, new_order)``.
        """
        # Only existing keys are reassigned, so iterating while writing is safe.
        for module in incremental_state:
            for key in incremental_state[module]:
                result = incremental_state[module][key].index_select(0, new_order)
                incremental_state[module][key] = result


@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
    """Fill in the ``lm_base`` defaults on ``args``, in place.

    Attributes already present on ``args`` are kept, with one exception:
    ``decoder_normalize_before`` is always forced to ``True``.
    """
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.activation_fn = getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)

    args.base_layers = getattr(args, "base_layers", 0)
    args.base_sublayers = getattr(args, "base_sublayers", 1)
    args.base_shuffle = getattr(args, "base_shuffle", False)

    args.add_bos_token = getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)

    # These two default to decoder_embed_dim, which was resolved above.
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    # Offloading only makes sense on checkpointed activations, so imply it.
    if args.offload_activations:
        args.checkpoint_activations = True