modify rms norm and value dim in retention

main
sunyt32 2023-09-28 14:24:37 +07:00
parent d1fefe9c22
commit 5c89ffbeea
6 changed files with 208 additions and 60 deletions

@ -31,8 +31,8 @@ logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use"}
activation_fn: str = field(
default="swish", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
activation_dropout: float = field(
@ -44,6 +44,9 @@ class LanguageConfig(FairseqDataclass):
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_value_embed_dim: int = field(
default=864, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
@ -51,14 +54,14 @@ class LanguageConfig(FairseqDataclass):
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
default=864, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_retention_heads: int = field(
default=2, metadata={"help": "num decoder retention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each decoder block"}
default=False, metadata={"help": "apply norm before each decoder block"}
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
@ -67,8 +70,8 @@ class LanguageConfig(FairseqDataclass):
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add layernorm to embedding"}
norm_embedding: bool = field(
default=False, metadata={"help": "add norm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
@ -276,14 +279,15 @@ def retnet_base_architecture(args):
args.dropout = getattr(args, "dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 864)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 864)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.activation_fn = getattr(args, "activation_fn", "swish")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
@ -321,7 +325,7 @@ def retnet_base_architecture(args):
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.norm_embedding = getattr(args, "norm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
@ -330,7 +334,8 @@ def retnet_base_architecture(args):
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 1728)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1728)
args.decoder_layers = getattr(args, "decoder_layers", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
retnet_base_architecture(args)
@ -338,7 +343,8 @@ def retnet_medium(args):
@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 3456)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3456)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@ -346,7 +352,8 @@ def retnet_xl(args):
@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 4280)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4280)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@ -354,7 +361,8 @@ def retnet_3b(args):
@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 6912)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6912)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@ -362,7 +370,8 @@ def retnet_7b(args):
@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 8560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8560)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@ -370,7 +379,8 @@ def retnet_13b(args):
@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
args.decoder_value_embed_dim = getattr(args, "decoder_value_embed_dim", 13824)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 13824)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)

@ -212,8 +212,9 @@ class EncoderDecoderConfig(object):
class RetNetConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_value_embed_dim = kwargs.pop("decoder_value_embed_dim", 1280)
self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1536)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1280)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
@ -221,7 +222,7 @@ class RetNetConfig(object):
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.norm_embedding = kwargs.pop("norm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
@ -244,7 +245,7 @@ class RetNetConfig(object):
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
self.norm_eps = kwargs.pop("norm_eps", 1e-6)
# Blockwise
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)

@ -11,14 +11,11 @@ from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.gate_linear_unit import GLU, make_experts
from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from torchscale.component.rms_norm import RMSNorm
class RetNetRelPos(nn.Module):
@ -91,7 +88,7 @@ class DecoderLayer(nn.Module):
self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
@ -123,7 +120,7 @@ class DecoderLayer(nn.Module):
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps)
if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
@ -131,20 +128,19 @@ class DecoderLayer(nn.Module):
self.alpha = 1.0
def build_ffn(self, embed_dim, args):
return FeedForwardNetwork(
return GLU(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
def build_retention(self, embed_dim, args):
return MultiScaleRetention(
args,
embed_dim,
args.decoder_value_embed_dim,
args.decoder_retention_heads,
)
@ -224,10 +220,10 @@ class RetNetDecoder(nn.Module):
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
if args.norm_embedding:
self.norm_embedding = RMSNorm(embed_dim, eps=args.norm_eps)
else:
self.layernorm_embedding = None
self.norm_embedding = None
self.layers = nn.ModuleList([])
@ -245,7 +241,7 @@ class RetNetDecoder(nn.Module):
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
self.layer_norm = RMSNorm(embed_dim, eps=args.norm_eps)
else:
self.layer_norm = None
@ -265,17 +261,6 @@ class RetNetDecoder(nn.Module):
):
p.data.div_(init_scale)
if args.subln:
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
def build_output_projection(
self,
args,
@ -324,8 +309,8 @@ class RetNetDecoder(nn.Module):
x = embed = self.embed_scale * token_embedding
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
if self.norm_embedding is not None:
x = self.norm_embedding(x)
x = self.dropout_module(x)

@ -0,0 +1,132 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
import torch.nn.functional as F
from .xmoe.global_groups import get_moe_group
class set_torch_seed(object):
def __init__(self, seed):
assert isinstance(seed, int)
self.rng_state = self.get_rng_state()
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def get_rng_state(self):
state = {"torch_rng_state": torch.get_rng_state()}
if torch.cuda.is_available():
state["cuda_rng_state"] = torch.cuda.get_rng_state()
return state
def set_rng_state(self, state):
torch.set_rng_state(state["torch_rng_state"])
if torch.cuda.is_available():
torch.cuda.set_rng_state(state["cuda_rng_state"])
def __enter__(self):
return self
def __exit__(self, *exc):
self.set_rng_state(self.rng_state)
def make_experts(args, embed_dim, expert_ffn_dim):
world_size = (
1
if not torch.distributed.is_initialized()
else torch.distributed.get_world_size()
)
expert_list = []
ddp_rank = args.ddp_rank
start_seed = torch.randint(1000000, (1,)).item()
# at least as many experts than gpus
if args.moe_expert_count >= world_size:
assert (
args.moe_expert_count % world_size == 0
), f"{args.moe_expert_count}, {world_size}"
local_moe_expert_count = args.moe_expert_count // world_size
for i in range(local_moe_expert_count):
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
expert_list.append(
GLU(
embed_dim,
expert_ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
)
else:
assert (
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"
moe_idx, _ = get_moe_group(args.moe_expert_count)
with set_torch_seed(start_seed + moe_idx):
expert_list.append(
GLU(
embed_dim,
expert_ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
)
experts = nn.ModuleList(expert_list)
return experts
def get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
elif activation == "swish":
return F.silu
else:
raise NotImplementedError
class GLU(nn.Module):
def __init__(
self,
embed_dim,
ffn_dim,
activation_fn,
dropout,
activation_dropout,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
def reset_parameters(self):
self.fc1.reset_parameters()
self.fc2.reset_parameters()
def forward(self, x):
x_shape = x.shape
x = x.reshape(-1, x.size(-1))
g = self.gate(x)
x = self.fc1(x)
x = self.activation_fn(x.float()).type_as(x) * g
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = x.view(x_shape)
x = self.dropout_module(x)
return x

@ -1,15 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import torch.nn.functional as F
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .rms_norm import RMSNorm
from .multiway_network import MultiwayWrapper
@ -45,29 +41,29 @@ class MultiScaleRetention(nn.Module):
self,
args,
embed_dim,
value_dim,
num_heads,
value_factor=2,
gate_fn="swish",
):
super().__init__()
self.args = args
self.factor = value_factor
self.embed_dim = embed_dim
self.value_dim = value_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim * self.factor // num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.norm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):
@ -76,7 +72,6 @@ class MultiScaleRetention(nn.Module):
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight)
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()

@ -0,0 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter('weight', None)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if self.weight is not None:
output = output * self.weight
return output