diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index 5e7ac92..14e7bfc 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -2,57 +2,72 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 
-class EncoderConfig:
+class Config:
     def __init__(self, **kwargs):
+        self.activation_fn = kwargs.pop("activation_fn", "gelu")
+        self.dropout = kwargs.pop("dropout", 0.0)
+        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
+        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
+        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
+        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
+        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
+        self.moe_freq = kwargs.pop("moe_freq", 0)
+        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
+        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
+        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
+        self.moe_eval_capacity_token_fraction = kwargs.pop(
+            "moe_eval_capacity_token_fraction", 0.25
+        )
+        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
+        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
+            "moe_normalize_gate_prob_before_dropping", False
+        )
+        self.use_xmoe = kwargs.pop("use_xmoe", False)
+        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
+        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
+        self.deepnorm = kwargs.pop("deepnorm", False)
+        self.subln = kwargs.pop("subln", True)
+        self.bert_init = kwargs.pop("bert_init", False)
+        self.multiway = kwargs.pop("multiway", False)
+        self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
+        self.vocab_size = kwargs.pop("vocab_size", -1)
+        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
+        self.fsdp = kwargs.pop("fsdp", False)
+        self.ddp_rank = kwargs.pop("ddp_rank", 0)
+        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
+        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
+
+        if self.use_xmoe:
+            self.moe_normalize_gate_prob_before_dropping = True
+            self.moe_second_expert_policy = "random"
+            assert self.moe_freq > 0 and self.moe_expert_count > 0
+
+    def override(self, args):
+        for hp in self.__dict__.keys():
+            if getattr(args, hp, None) is not None:
+                self.__dict__[hp] = getattr(args, hp, None)
+
+
+class EncoderConfig(Config):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
         self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
         self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
         self.encoder_layers = kwargs.pop("encoder_layers", 12)
         self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
         self.normalize_output = kwargs.pop("normalize_output", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
         self.share_encoder_input_output_embed = kwargs.pop(
             "share_encoder_input_output_embed", False
         )
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
         # Vision
         self.img_size = kwargs.pop("img_size", 224)
         self.patch_size = kwargs.pop("patch_size", 16)
         self.in_chans = kwargs.pop("in_chans", 3)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
+
         if self.deepnorm:
             self.encoder_normalize_before = False
@@ -60,63 +75,22 @@ class EncoderConfig:
         if self.subln:
             self.encoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
+
 
 
-class DecoderConfig:
+class DecoderConfig(Config):
     def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
         self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
         self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
         self.decoder_layers = kwargs.pop("decoder_layers", 12)
         self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
         self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
         self.share_decoder_input_output_embed = kwargs.pop(
             "share_decoder_input_output_embed", False
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
 
         if self.deepnorm:
             self.decoder_normalize_before = False
@@ -124,85 +98,10 @@ class DecoderConfig:
         if self.subln:
             self.decoder_normalize_before = True
             self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
 
 
-class EncoderDecoderConfig:
+
+class EncoderDecoderConfig(EncoderConfig, DecoderConfig):
     def __init__(self, **kwargs):
-        self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
-        self.encoder_attention_heads = kwargs.pop("encoder_attention_heads", 12)
-        self.encoder_ffn_embed_dim = kwargs.pop("encoder_ffn_embed_dim", 3072)
-        self.encoder_layers = kwargs.pop("encoder_layers", 12)
-        self.encoder_normalize_before = kwargs.pop("encoder_normalize_before", True)
-        self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
-        self.decoder_attention_heads = kwargs.pop("decoder_attention_heads", 12)
-        self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 3072)
-        self.decoder_layers = kwargs.pop("decoder_layers", 12)
-        self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
-        self.activation_fn = kwargs.pop("activation_fn", "gelu")
-        self.dropout = kwargs.pop("dropout", 0.0)
-        self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
-        self.attention_dropout = kwargs.pop("attention_dropout", 0.0)
-        self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
-        self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
-        self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
-        self.moe_freq = kwargs.pop("moe_freq", 0)
-        self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
-        self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
-        self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-        self.moe_eval_capacity_token_fraction = kwargs.pop(
-            "moe_eval_capacity_token_fraction", 0.25
-        )
-        self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-        self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
-            "moe_normalize_gate_prob_before_dropping", False
-        )
-        self.use_xmoe = kwargs.pop("use_xmoe", False)
-        self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
-        self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
-        self.deepnorm = kwargs.pop("deepnorm", False)
-        self.subln = kwargs.pop("subln", True)
-        self.bert_init = kwargs.pop("bert_init", False)
-        self.multiway = kwargs.pop("multiway", False)
+        super().__init__(**kwargs)
         self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
-        self.share_decoder_input_output_embed = kwargs.pop(
-            "share_decoder_input_output_embed", False
-        )
-        self.max_source_positions = kwargs.pop("max_source_positions", 1024)
-        self.max_target_positions = kwargs.pop("max_target_positions", 1024)
-        self.no_output_layer = kwargs.pop("no_output_layer", False)
-        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
-        # Text
-        self.vocab_size = kwargs.pop("vocab_size", -1)
-        # Fairscale
-        self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
-        self.fsdp = kwargs.pop("fsdp", False)
-        self.ddp_rank = kwargs.pop("ddp_rank", 0)
-        self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
-        self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
-
-        if self.deepnorm:
-            self.encoder_normalize_before = False
-            self.decoder_normalize_before = False
-            self.subln = False
-        if self.subln:
-            self.encoder_normalize_before = True
-            self.decoder_normalize_before = True
-            self.deepnorm = False
-        if self.use_xmoe:
-            self.moe_normalize_gate_prob_before_dropping = True
-            self.moe_second_expert_policy = "random"
-            assert self.moe_freq > 0 and self.moe_expert_count > 0
-
-    def override(self, args):
-        for hp in self.__dict__.keys():
-            if getattr(args, hp, None) is not None:
-                self.__dict__[hp] = getattr(args, hp, None)
diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index ed407b0..e328834 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -9,12 +9,14 @@ import torch.nn as nn
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+
 try:
     from apex.normalization import FusedLayerNorm as LayerNorm
 except ModuleNotFoundError:
@@ -23,7 +25,7 @@ except ModuleNotFoundError:
 class DecoderLayer(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         depth,
         is_moe_layer=False,
         is_encoder_decoder=False,
@@ -209,7 +211,7 @@ class DecoderLayer(nn.Module):
 class Decoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: DecoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index 62ab174..c20fa65 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -13,6 +13,7 @@ except ModuleNotFoundError:
     from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
+from torchscale.architecture.config import EncoderConfig
 from torchscale.component.droppath import DropPath
 from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
 from torchscale.component.multihead_attention import MultiheadAttention
@@ -23,7 +24,11 @@ from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
 
 
 class EncoderLayer(nn.Module):
-    def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
+    def __init__(self,
+                 args: EncoderConfig,
+                 depth,
+                 is_moe_layer: bool = False,
+                 is_encoder_decoder: bool = False):
         super().__init__()
         self.args = args
         self.embed_dim = args.encoder_embed_dim
@@ -165,11 +170,11 @@ class EncoderLayer(nn.Module):
 class Encoder(nn.Module):
     def __init__(
         self,
-        args,
+        args: EncoderConfig,
         embed_tokens=None,
         embed_positions=None,
         output_projection=None,
-        is_encoder_decoder=False,
+        is_encoder_decoder: bool = False,
         **kwargs
     ):
         self.args = args
diff --git a/torchscale/architecture/encoder_decoder.py b/torchscale/architecture/encoder_decoder.py
index 91a906e..d91313f 100644
--- a/torchscale/architecture/encoder_decoder.py
+++ b/torchscale/architecture/encoder_decoder.py
@@ -3,6 +3,7 @@
 
 import torch.nn as nn
 
+from torchscale.architecture.config import EncoderDecoderConfig
 from torchscale.architecture.decoder import Decoder
 from torchscale.architecture.encoder import Encoder
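Usage sketch (reviewer note, not part of the patch): a minimal check of how the refactored config hierarchy behaves once the diff above is applied. The hyperparameter values and the argparse.Namespace passed to override() are illustrative only.

# Minimal sketch, assuming the diff above is applied; values are arbitrary.
from argparse import Namespace

from torchscale.architecture.config import EncoderDecoderConfig

# Shared hyperparameters (vocab_size, dropout, deepnorm, ...) are now consumed by the
# Config base class, while role-specific ones stay in EncoderConfig / DecoderConfig.
# MRO: EncoderDecoderConfig -> EncoderConfig -> DecoderConfig -> Config.
cfg = EncoderDecoderConfig(
    vocab_size=64000,
    encoder_embed_dim=768,
    decoder_embed_dim=768,
    deepnorm=True,  # deepnorm turns off pre-LN and subln in both stacks
)
assert cfg.encoder_normalize_before is False
assert cfg.decoder_normalize_before is False
assert cfg.subln is False

# override() keeps its previous contract: any attribute on args that is not None wins.
cfg.override(Namespace(dropout=0.1, attention_dropout=None))
assert cfg.dropout == 0.1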