Support PyTorch LayerNorm
parent 82f140a6c4
commit 9f105b591d
@@ -8,7 +8,6 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairseq import utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
@@ -16,6 +15,10 @@ from fairseq.models.squad import SQuADHead
 from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
 from fairseq.modules import PositionalEmbedding
 from omegaconf import II
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from torchscale.architecture.config import EncoderConfig
 
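The guarded import above is the core of the change: Apex's fused kernel is picked up when it is installed, and torch.nn.LayerNorm is used otherwise, with no change to calling code. A minimal standalone sketch of how the fallback resolves (not part of the diff; shapes and values are arbitrary):

    try:
        # Fused CUDA implementation, only available when NVIDIA Apex is installed
        from apex.normalization import FusedLayerNorm as LayerNorm
    except ModuleNotFoundError:
        # Pure-PyTorch implementation with a compatible constructor
        from torch.nn import LayerNorm

    ln = LayerNorm(8, eps=1e-5)   # (normalized_shape, eps) works for both backends
    print(type(ln).__name__)      # "FusedLayerNorm" if Apex is importable, else "LayerNorm"
    print(ln.eps)                 # 1e-05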
setup.py
@@ -17,7 +17,7 @@ setup(
     license="MIT",
     url="https://github.com/msranlp/torchscale",
     packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
-    install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
+    install_requires=["torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
     python_requires=">=3.8.0",
     classifiers=[
         "Programming Language :: Python :: 3",
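With apex dropped from install_requires it becomes an optional dependency: installing the package no longer pulls it in, and the guarded imports decide at runtime which backend is active. A rough environment check (a sketch; apex being importable does not by itself guarantee the fused CUDA extension was built):

    import importlib.util

    has_apex = importlib.util.find_spec("apex") is not None
    print("LayerNorm backend:", "apex FusedLayerNorm" if has_apex else "torch.nn.LayerNorm")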
@@ -39,6 +39,7 @@ class EncoderConfig(object):
         )
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Vision
@@ -106,6 +107,7 @@ class DecoderConfig(object):
         )
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
@@ -176,6 +178,7 @@ class EncoderDecoderConfig(object):
         self.max_source_positions = kwargs.pop("max_source_positions", 1024)
         self.max_target_positions = kwargs.pop("max_target_positions", 1024)
         self.no_output_layer = kwargs.pop("no_output_layer", False)
+        self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
         # Text
         self.vocab_size = kwargs.pop("vocab_size", -1)
         # Fairscale
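Each config class now carries a layernorm_eps value (default 1e-5, the same default as torch.nn.LayerNorm) that the encoder and decoder below thread into every normalization layer. A usage sketch based only on the kwargs.pop pattern shown above:

    from torchscale.architecture.config import EncoderConfig

    cfg = EncoderConfig(vocab_size=64000)
    print(cfg.layernorm_eps)    # 1e-05 (default)

    cfg = EncoderConfig(vocab_size=64000, layernorm_eps=1e-6)
    print(cfg.layernorm_eps)    # 1e-06 (override)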
@@ -6,7 +6,6 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
 
 from torchscale.architecture.utils import init_bert_params
@@ -16,7 +15,10 @@ from torchscale.component.multihead_attention import MultiheadAttention
 from torchscale.component.relative_position_bias import RelativePositionBias
 from torchscale.component.xmoe.moe_layer import MOELayer
 from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
 
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 class DecoderLayer(nn.Module):
     def __init__(
@@ -43,14 +45,14 @@ class DecoderLayer(nn.Module):
 
         self.normalize_before = args.decoder_normalize_before
 
-        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if not is_encoder_decoder:
             self.encoder_attn = None
             self.encoder_attn_layer_norm = None
         else:
             self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
-            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         self.is_moe_layer = is_moe_layer
         self.ffn_dim = args.decoder_ffn_embed_dim
@@ -82,7 +84,7 @@ class DecoderLayer(nn.Module):
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
 
-        self.final_layer_norm = LayerNorm(self.embed_dim)
+        self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -99,6 +101,7 @@ class DecoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
            args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
 
@@ -233,7 +236,7 @@ class Decoder(nn.Module):
         self.output_projection = output_projection
 
         if args.layernorm_embedding:
-            self.layernorm_embedding = LayerNorm(embed_dim)
+            self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layernorm_embedding = None
 
@@ -254,7 +257,7 @@ class Decoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.decoder_normalize_before:
-            self.layer_norm = LayerNorm(embed_dim)
+            self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
         else:
             self.layer_norm = None
 
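Since args.layernorm_eps now reaches every LayerNorm in the decoder (self-attention, encoder-attention, final, embedding, and FFN norms), an override set on the config propagates end to end. A hedged sanity check, assuming the usual torchscale entry points and the attribute names visible in the hunks above:

    from torchscale.architecture.config import DecoderConfig
    from torchscale.architecture.decoder import Decoder

    cfg = DecoderConfig(vocab_size=64000, layernorm_eps=1e-6)
    decoder = Decoder(cfg)

    # Both backends (torch.nn.LayerNorm and apex FusedLayerNorm) expose .eps
    print(decoder.layers[0].self_attn_layer_norm.eps)   # expected: 1e-06
    print(decoder.layers[0].final_layer_norm.eps)       # expected: 1e-06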
@@ -6,8 +6,11 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from apex.normalization import FusedLayerNorm as LayerNorm
 from fairscale.nn import checkpoint_wrapper, wrap
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from torchscale.architecture.utils import init_bert_params
 from torchscale.component.droppath import DropPath
@@ -25,7 +28,7 @@ class EncoderLayer(nn.Module):
         self.args = args
         self.embed_dim = args.encoder_embed_dim
         self.self_attn = self.build_self_attention(self.embed_dim, args)
-        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.self_attn_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
         self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
 
         if args.drop_path_rate > 0:
@@ -70,7 +73,7 @@ class EncoderLayer(nn.Module):
             )
             experts = make_experts(args, self.embed_dim, self.ffn_dim)
             self.moe_layer = MOELayer(gate, experts, args)
-        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim))
+        self.final_layer_norm = MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
 
         if args.deepnorm:
             if is_encoder_decoder:
@@ -92,6 +95,7 @@ class EncoderLayer(nn.Module):
             args.activation_fn,
             args.dropout,
             args.activation_dropout,
+            args.layernorm_eps,
             args.subln,
         )
 
@@ -190,7 +194,7 @@ class Encoder(nn.Module):
 
         if args.layernorm_embedding:
             self.layernorm_embedding = MultiwayWrapper(
-                args, LayerNorm(embed_dim), dim=1
+                args, LayerNorm(embed_dim, eps=args.layernorm_eps), dim=1
             )
         else:
             self.layernorm_embedding = None
@@ -211,7 +215,7 @@ class Encoder(nn.Module):
         self.num_layers = len(self.layers)
 
         if args.encoder_normalize_before:
-            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim))
+            self.layer_norm = MultiwayWrapper(args, LayerNorm(embed_dim, eps=args.layernorm_eps))
         else:
             self.layer_norm = None
 
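For reference, the eps being threaded through is the constant added to the variance before the square root: both backends compute y = (x - mean) / sqrt(var + eps) * weight + bias over the last dimension. A standalone check against torch.nn.LayerNorm:

    import torch

    d, eps = 512, 1e-6
    x = torch.randn(4, d)

    ln = torch.nn.LayerNorm(d, eps=eps)   # weight=1, bias=0 at initialization
    manual = (x - x.mean(-1, keepdim=True)) / torch.sqrt(
        x.var(-1, unbiased=False, keepdim=True) + eps
    )

    print(torch.allclose(ln(x), manual, atol=1e-6))   # True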
@@ -4,7 +4,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 
 class set_torch_seed(object):
@@ -58,6 +61,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
+                        args.layernorm_eps,
                        args.subln,
                    )
                )
@@ -74,6 +78,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
+                        args.layernorm_eps,
                        args.subln,
                    )
                )
@@ -98,6 +103,7 @@ class FeedForwardNetwork(nn.Module):
         activation_fn,
         dropout,
         activation_dropout,
+        layernorm_eps,
         subln=False,
     ):
         super().__init__()
@@ -109,7 +115,7 @@ class FeedForwardNetwork(nn.Module):
         self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
         self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
         self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
-        self.ffn_layernorm = LayerNorm(ffn_dim) if subln else None
+        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
 
     def reset_parameters(self):
         self.fc1.reset_parameters()
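Note that layernorm_eps is inserted as a positional parameter before subln, so every caller of FeedForwardNetwork has to be updated in the same commit (build_ffn and make_experts above). A hedged construction sketch; the module path and the two leading positional arguments (embedding and FFN dimensions) are inferred from the attribute names in the hunks, not shown in the diff:

    from torchscale.component.feedforward_network import FeedForwardNetwork

    ffn = FeedForwardNetwork(
        768,       # embed_dim (assumed first positional argument)
        3072,      # ffn_dim (assumed second positional argument)
        "gelu",    # activation_fn
        0.1,       # dropout
        0.0,       # activation_dropout
        1e-5,      # layernorm_eps (the new argument)
        True,      # subln, which enables the ffn_layernorm that consumes the eps
    )
    print(ffn.ffn_layernorm.eps)   # 1e-05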
@@ -5,8 +5,11 @@ import math
 
 import torch
 import torch.nn.functional as F
-from apex.normalization import FusedLayerNorm as LayerNorm
 from torch import nn
+try:
+    from apex.normalization import FusedLayerNorm as LayerNorm
+except ModuleNotFoundError:
+    from torch.nn import LayerNorm
 
 from .multiway_network import MultiwayWrapper
 from .xpos_relative_position import XPOS
@@ -41,7 +44,7 @@ class MultiheadAttention(nn.Module):
             args, nn.Linear(embed_dim, embed_dim, bias=True)
         )
         self.inner_attn_ln = (
-            MultiwayWrapper(args, LayerNorm(self.embed_dim))
+            MultiwayWrapper(args, LayerNorm(self.embed_dim, eps=args.layernorm_eps))
             if subln and self.self_attention
             else None
         )