Support PyTorch LayerNorm

shumingma 2023-01-16 20:17:28 -08:00
parent 82f140a6c4
commit 9f105b591d
7 changed files with 40 additions and 18 deletions

View File

@@ -8,7 +8,6 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
@@ -16,6 +15,10 @@ from fairseq.models.squad import SQuADHead
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from omegaconf import II
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.architecture.config import EncoderConfig

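The change above makes apex optional: when apex is not installed, torch.nn.LayerNorm is used in place of FusedLayerNorm. A minimal, self-contained sketch of the same fallback, with a diagnostic print added here (not part of the commit) to show which backend was picked:

```python
# Optional-apex import pattern used throughout this commit: prefer the fused
# CUDA kernel when apex is available, otherwise fall back to the stock
# PyTorch implementation. Both classes accept (normalized_shape, eps=...).
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

# Diagnostic line added for this sketch only.
print(f"Using {LayerNorm.__module__}.{LayerNorm.__name__}")
```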
View File

@@ -17,7 +17,7 @@ setup(
license="MIT",
url="https://github.com/msranlp/torchscale",
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
python_requires=">=3.8.0",
classifiers=[
"Programming Language :: Python :: 3",

View File

@@ -39,6 +39,7 @@ class EncoderConfig(object):
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Vision
@@ -106,6 +107,7 @@ class DecoderConfig(object):
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
@@ -176,6 +178,7 @@ class EncoderDecoderConfig(object):
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale

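The new layernorm_eps option is read via kwargs.pop with a default of 1e-5, so existing configs keep their previous behavior while callers can override it. A small usage sketch (the vocab_size value is just an illustrative placeholder):

```python
from torchscale.architecture.config import DecoderConfig

# Defaults to 1e-5 when not given, matching the previously hard-coded value.
default_cfg = DecoderConfig(vocab_size=64000)
print(default_cfg.layernorm_eps)  # 1e-05

# Override per model when a different epsilon is needed.
tight_cfg = DecoderConfig(vocab_size=64000, layernorm_eps=1e-6)
print(tight_cfg.layernorm_eps)  # 1e-06
```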
View File

@@ -6,7 +6,6 @@ import math
import numpy as np
import torch
import torch.nn as nn
from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
@@ -16,7 +15,10 @@ from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
class DecoderLayer(nn.Module):
def __init__(
@@ -43,14 +45,14 @@ class DecoderLayer(nn.Module):
self.normalize_before = args.decoder_normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
if not is_encoder_decoder:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
@@ -82,7 +84,7 @@ class DecoderLayer(nn.Module):
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = LayerNorm(self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm:
if is_encoder_decoder:
@@ -99,6 +101,7 @@ class DecoderLayer(nn.Module):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
@@ -233,7 +236,7 @@ class Decoder(nn.Module):
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
else:
self.layernorm_embedding = None
@@ -254,7 +257,7 @@ class Decoder(nn.Module):
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim)
self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
else:
self.layer_norm = None

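Every LayerNorm the decoder builds now receives args.layernorm_eps instead of relying on the backend default. A minimal sketch of that construction pattern, using a bare namespace in place of the real config object and torch.nn.LayerNorm standing in for the apex fallback:

```python
from types import SimpleNamespace

import torch.nn as nn

# Stand-in for the decoder config; only the field used here is populated.
args = SimpleNamespace(layernorm_eps=1e-6)
embed_dim = 1024

# Same construction as self_attn_layer_norm / final_layer_norm above.
self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=args.layernorm_eps)
print(self_attn_layer_norm.eps)  # 1e-06
```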
View File

@@ -6,8 +6,11 @@ import math
import numpy as np
import torch
import torch.nn as nn
from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
@@ -25,7 +28,7 @@ class EncoderLayer(nn.Module):
self.args = args
self.embed_dim = args.encoder_embed_dim
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
@@ -70,7 +73,7 @@ class EncoderLayer(nn.Module):
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
if args.deepnorm:
if is_encoder_decoder:
@@ -92,6 +95,7 @@ class EncoderLayer(nn.Module):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
@@ -190,7 +194,7 @@ class Encoder(nn.Module):
if args.layernorm_embedding:
self.layernorm_embedding = MultiwayWrapper(
args, LayerNorm(embed_dim), dim=1
args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1
)
else:
self.layernorm_embedding = None
@@ -211,7 +215,7 @@ class Encoder(nn.Module):
self.num_layers = len(self.layers)
if args.encoder_normalize_before:
self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
else:
self.layer_norm = None

View File

@@ -4,7 +4,10 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
class set_torch_seed(object):
@@ -58,6 +61,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
)
@@ -74,6 +78,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
)
@@ -98,6 +103,7 @@ class FeedForwardNetwork(nn.Module):
activation_fn,
dropout,
activation_dropout,
layernorm_eps,
subln=False,
):
super().__init__()
@@ -109,7 +115,7 @@ class FeedForwardNetwork(nn.Module):
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
def reset_parameters(self):
self.fc1.reset_parameters()

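FeedForwardNetwork now takes layernorm_eps as an extra constructor argument just before subln, and every call site (the encoder/decoder layers and make_experts) forwards args.layernorm_eps. A hedged sketch of the updated call, assuming the leading positional parameters are the embedding and FFN dimensions as suggested by the hunks above:

```python
from torchscale.component.feedforward_network import FeedForwardNetwork

# Positional arguments mirror the call sites in this commit:
# embed_dim, ffn_dim, activation_fn, dropout, activation_dropout,
# layernorm_eps, subln. With subln=True the inner ffn_layernorm is built
# with the configured epsilon.
ffn = FeedForwardNetwork(768, 3072, "gelu", 0.1, 0.0, 1e-5, True)
print(ffn.ffn_layernorm.eps)  # 1e-05
```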
View File

@@ -5,8 +5,11 @@ import math
import torch
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .multiway_network import MultiwayWrapper
from .xpos_relative_position import XPOS
@@ -41,7 +44,7 @@ class MultiheadAttention(nn.Module):
args, nn.Linear(embed_dim, embed_dim, bias=True)
)
self.inner_attn_ln = (
MultiwayWrapper(args, LayerNorm(self.embed_dim))
MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
if subln and self.self_attention
else None
)
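Since torch.nn.LayerNorm is treated as a drop-in replacement for FusedLayerNorm, a quick way to convince yourself the fallback is safe is to compare the two backends numerically when apex happens to be installed. A hedged sanity-check sketch (it silently does nothing without apex or a GPU):

```python
import torch
import torch.nn as nn

try:
    from apex.normalization import FusedLayerNorm
except ModuleNotFoundError:
    FusedLayerNorm = None

if FusedLayerNorm is not None and torch.cuda.is_available():
    dim, eps = 64, 1e-5
    fused = FusedLayerNorm(dim, eps=eps).cuda()
    plain = nn.LayerNorm(dim, eps=eps).cuda()
    # Copy weights so both modules share the same affine parameters.
    plain.load_state_dict(fused.state_dict())
    x = torch.randn(2, 8, dim, device="cuda")
    print(torch.allclose(fused(x), plain(x), atol=1e-5))  # expected: True
```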