From 9f105b591d03ea762b853f0697ee6c2ab8b16ec6 Mon Sep 17 00:00:00 2001
From: shumingma
Date: Mon, 16 Jan 2023 20:17:28 -0800
Subject: [PATCH] Support PyTorch LayerNorm

---
 examples/fairseq/models/bert.py             |  5 ++++-
 setup.py                                    |  2 +-
 torchscale/architecture/config.py           |  3 +++
 torchscale/architecture/decoder.py          | 17 ++++++++++-------
 torchscale/architecture/encoder.py          | 14 +++++++++-----
 torchscale/component/feedforward_network.py | 10 ++++++++--
 torchscale/component/multihead_attention.py |  7 +++++--
 7 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py
index 42e2687..8327484 100644
--- a/examples/fairseq/models/bert.py
+++ b/examples/fairseq/models/bert.py
@@ -8,7 +8,6 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairseq import utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
@@ -16,6 +15,10 @@ from fairseq.models.squad import SQuADHead
 from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
 from fairseq.modules import PositionalEmbedding
 from omegaconf import II
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from torchscale.architecture.config import EncoderConfig
 
diff --git a/setup.py b/setup.py
index eb6a3f2..4593330 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@ setup(
     license="MIT",
     url="https://github.com/msranlp/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
+    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
     python_requires=">=3.8.0",
     classifiers=[
         "Programming Language :: Python :: 3",
diff --git a/torchscale/architecture/config.py b/torchscale/architecture/config.py
index 6b8bb5c..6aa7e16 100644
--- a/torchscale/architecture/config.py
+++ b/torchscale/architecture/config.py
@@ -39,6 +39,7 @@ class EncoderConfig(object):
         )
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Vision
@@ -106,6 +107,7 @@ class DecoderConfig(object):
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
@@ -176,6 +178,7 @@ class EncoderDecoderConfig(object):
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
diff --git a/torchscale/architecture/decoder.py b/torchscale/architecture/decoder.py
index 942a0cb..2dea15f 100644
--- a/torchscale/architecture/decoder.py
+++ b/torchscale/architecture/decoder.py
@@ -6,7 +6,6 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
@@ -16,7 +15,10 @@ from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 class DecoderLayer(nn.Module):
     def __init__(
@@ -43,14 +45,14 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if not is_encoder_decoder:
             self.encoder_attn = None
             self.encoder_attn_layer_norm = None
         else:
             self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
-            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -82,7 +84,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = LayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -99,6 +101,7 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
 
@@ -233,7 +236,7 @@ class Decoder(nn.Module):
             self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim)
+            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -254,7 +257,7 @@ class Decoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = LayerNorm(embed_dim)
+            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
diff --git a/torchscale/architecture/encoder.py b/torchscale/architecture/encoder.py
index a9c6b78..c47238b 100644
--- a/torchscale/architecture/encoder.py
+++ b/torchscale/architecture/encoder.py
@@ -6,8 +6,11 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
@@ -25,7 +28,7 @@ class EncoderLayer(nn.Module):
         self.args = args
         self.embed_dim = args.encoder_embed_dim
         self.self_attn = self.build_self_attention(self.embed_dim, args)
-        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
         self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
 
         if args.drop_path_rate > 0:
@@ -70,7 +73,7 @@ class EncoderLayer(nn.Module):
             )
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
-        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -92,6 +95,7 @@ class EncoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
 
@@ -190,7 +194,7 @@ class Encoder(nn.Module):
 
         if args.layernorm_embedding:
             self.layernorm_embedding = MultiwayWrapper(
-                args, LayerNorm(embed_dim), dim=1
+                args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1
             )
         else:
             self.layernorm_embedding = None
@@ -211,7 +215,7 @@ class Encoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.encoder_normalize_before:
-            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
+            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
 
diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index 31c0651..0c872ce 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -4,7 +4,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 
 class set_torch_seed(object):
@@ -58,6 +61,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                     args.activation_fn,
                     args.dropout,
                     args.activation_dropout,
+                    args.layernorm_eps,
                     args.subln,
                 )
             )
@@ -74,6 +78,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                 args.activation_fn,
                 args.dropout,
                 args.activation_dropout,
+                args.layernorm_eps,
                 args.subln,
             )
         )
@@ -98,6 +103,7 @@ class FeedForwardNetwork(nn.Module):
         activation_fn,
         dropout,
         activation_dropout,
+        layernorm_eps,
         subln=False,
     ):
         super().__init__()
@@ -109,7 +115,7 @@ class FeedForwardNetwork(nn.Module):
         self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
         self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
         self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
-        self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
+        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
 
     def reset_parameters(self):
         self.fc1.reset_parameters()
diff --git a/torchscale/component/multihead_attention.py b/torchscale/component/multihead_attention.py
index b5f9c70..d255596 100644
--- a/torchscale/component/multihead_attention.py
+++ b/torchscale/component/multihead_attention.py
@@ -5,8 +5,11 @@ import math
 
 import torch
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from torch import nn
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from .multiway_network import MultiwayWrapper
 from .xpos_relative_position import XPOS
@@ -41,7 +44,7 @@ class MultiheadAttention(nn.Module):
             args, nn.Linear(embed_dim, embed_dim, bias=True)
         )
         self.inner_attn_ln = (
-            MultiwayWrapper(args, LayerNorm(self.embed_dim))
+            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
             if subln and self.self_attention
             else None
         )
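
With this patch applied, apex becomes an optional dependency: each module first tries apex's FusedLayerNorm and falls back to torch.nn.LayerNorm when apex is absent, and the new layernorm_eps option (default 1e-5, matching torch.nn.LayerNorm) is threaded through every LayerNorm construction. Below is a minimal sketch of what this enables on a machine without apex; the DecoderConfig/Decoder construction follows the README-style torchscale API, and the vocab_size value is purely illustrative:

    # Sketch only: build a torchscale Decoder with no apex installed.
    # The try/except mirrors the fallback pattern this patch adds to each module.
    try:
        from apex.normalization import FusedLayerNorm as LayerNorm  # fused CUDA kernel
    except ModuleNotFoundError:
        from torch.nn import LayerNorm  # pure-PyTorch fallback, also works on CPU

    from torchscale.architecture.config import DecoderConfig
    from torchscale.architecture.decoder import Decoder

    # layernorm_eps is the new config knob; 1e-5 is torch.nn.LayerNorm's default.
    config = DecoderConfig(vocab_size=64000, layernorm_eps=1e-5)
    decoder = Decoder(config)

One subtlety: ModuleNotFoundError is a subclass of ImportError, so an apex installation that is present but fails to import (for example, a plain ImportError from a CUDA version mismatch) will still propagate rather than trigger the fallback.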