Support Pytorch LayerNorm

shumingma 2023-01-16 20:17:28 -08:00
parent 82f140a6c4
commit 9f105b591d
7 changed files with 40 additions and 18 deletions
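In short, the hard apex dependency is dropped and every LayerNorm construction gains an explicit eps taken from a new layernorm_eps config field (default 1e-5). A minimal sketch of the pattern applied across the files below, assuming apex is now optional (the 768 is an illustrative dimension):

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    # apex not installed: fall back to the plain PyTorch implementation
    from torch.nn import LayerNorm

# eps is now threaded in from the configs (layernorm_eps, default 1e-5)
norm = LayerNorm(768, eps=1e-5)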

View File

@@ -8,7 +8,6 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairseq import utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
@@ -16,6 +15,10 @@ from fairseq.models.squad import SQuADHead
 from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
 from fairseq.modules import PositionalEmbedding
 from omegaconf import II
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 from torchscale.architecture.config import EncoderConfig

View File

@@ -17,7 +17,7 @@ setup(
     license="MIT",
     url="https://github.com/msranlp/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
+    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
     python_requires=">=3.8.0",
     classifiers=[
         "Programming Language :: Python :: 3",

View File

@@ -39,6 +39,7 @@ class EncoderConfig(object):
         )
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Vision
@@ -106,6 +107,7 @@ class DecoderConfig(object):
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
@@ -176,6 +178,7 @@ class EncoderDecoderConfig(object):
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
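Because the config constructors read options via kwargs.pop with a default, layernorm_eps can be supplied as an ordinary keyword argument, and older call sites keep working with the 1e-5 default. A hedged usage sketch (values are illustrative):

from torchscale.architecture.config import EncoderConfig

# Omitting layernorm_eps keeps the 1e-5 default; pass it explicitly to override.
config = EncoderConfig(vocab_size=64000, layernorm_eps=1e-6)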

View File

@@ -6,7 +6,6 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
@@ -16,7 +15,10 @@ from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 
 class DecoderLayer(nn.Module):
     def __init__(
@@ -43,14 +45,14 @@ class DecoderLayer(nn.Module):
         self.normalize_before = args.decoder_normalize_before
 
-        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if not is_encoder_decoder:
             self.encoder_attn = None
             self.encoder_attn_layer_norm = None
         else:
             self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
-            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -82,7 +84,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = LayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -99,6 +101,7 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
@@ -233,7 +236,7 @@ class Decoder(nn.Module):
         self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim)
+            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
@@ -254,7 +257,7 @@ class Decoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = LayerNorm(embed_dim)
+            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None

View File

@@ -6,8 +6,11 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
@@ -25,7 +28,7 @@ class EncoderLayer(nn.Module):
         self.args = args
         self.embed_dim = args.encoder_embed_dim
         self.self_attn = self.build_self_attention(self.embed_dim, args)
-        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
         self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
 
         if args.drop_path_rate > 0:
@@ -70,7 +73,7 @@ class EncoderLayer(nn.Module):
             )
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
-        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -92,6 +95,7 @@ class EncoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
@@ -190,7 +194,7 @@ class Encoder(nn.Module):
         if args.layernorm_embedding:
             self.layernorm_embedding = MultiwayWrapper(
-                args, LayerNorm(embed_dim), dim=1
+                args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1
             )
         else:
             self.layernorm_embedding = None
@@ -211,7 +215,7 @@ class Encoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.encoder_normalize_before:
-            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
+            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None

View File

@@ -4,7 +4,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 
 class set_torch_seed(object):
@@ -58,6 +61,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                         args.activation_fn,
                         args.dropout,
                         args.activation_dropout,
+                        args.layernorm_eps,
                         args.subln,
                     )
                 )
@@ -74,6 +78,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                     args.activation_fn,
                     args.dropout,
                     args.activation_dropout,
+                    args.layernorm_eps,
                     args.subln,
                 )
             )
@@ -98,6 +103,7 @@ class FeedForwardNetwork(nn.Module):
         activation_fn,
         dropout,
         activation_dropout,
+        layernorm_eps,
         subln=False,
     ):
         super().__init__()
@@ -109,7 +115,7 @@ class FeedForwardNetwork(nn.Module):
         self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
         self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
         self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
-        self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
+        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
 
     def reset_parameters(self):
         self.fc1.reset_parameters()
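For completeness, a hedged construction sketch of the updated FeedForwardNetwork. The two leading parameters are assumed to be embed_dim and ffn_dim (consistent with the fc1/fc2 lines above and the make_experts call sites); all values are illustrative:

ffn = FeedForwardNetwork(
    768,     # embed_dim (assumed leading parameter, illustrative value)
    3072,    # ffn_dim (assumed second parameter, illustrative value)
    "gelu",  # activation_fn
    0.1,     # dropout
    0.0,     # activation_dropout
    1e-5,    # layernorm_eps, the argument added by this commit
    subln=True,  # enables ffn_layernorm, which now uses layernorm_eps
)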

View File

@@ -5,8 +5,11 @@ import math
 
 import torch
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from torch import nn
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from .multiway_network import MultiwayWrapper
 from .xpos_relative_position import XPOS
@@ -41,7 +44,7 @@ class MultiheadAttention(nn.Module):
             args, nn.Linear(embed_dim, embed_dim, bias=True)
         )
         self.inner_attn_ln = (
-            MultiwayWrapper(args, LayerNorm(self.embed_dim))
+            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
             if subln and self.self_attention
             else None
         )