modify rms norm and value dim in retention
parent d1fefe9c22
commit 5c89ffbeea

@@ -0,0 +1,132 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F

from .xmoe.global_groups import get_moe_group


class set_torch_seed(object):
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)

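set_torch_seed is used as a context manager: on construction it saves the current CPU (and, if available, CUDA) RNG state and reseeds both generators; on exit it restores the saved state, so the seeded block does not disturb the surrounding RNG stream. A minimal sketch of that behaviour (a hypothetical check, not part of the diff):

before = torch.get_rng_state()
with set_torch_seed(1234):
    a = torch.randn(3)
with set_torch_seed(1234):
    b = torch.randn(3)
assert torch.equal(a, b)  # same seed -> identical draw
assert torch.equal(before, torch.get_rng_state())  # outer RNG state restored
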

def make_experts(args, embed_dim, expert_ffn_dim):
    world_size = (
        1
        if not torch.distributed.is_initialized()
        else torch.distributed.get_world_size()
    )
    expert_list = []
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()
    # at least as many experts as GPUs
    if args.moe_expert_count >= world_size:
        assert (
            args.moe_expert_count % world_size == 0
        ), f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    GLU(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                    )
                )
    else:
        assert (
            world_size % args.moe_expert_count == 0
        ), f"{world_size}, {args.moe_expert_count}"

        moe_idx, _ = get_moe_group(args.moe_expert_count)

        with set_torch_seed(start_seed + moe_idx):
            expert_list.append(
                GLU(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts


def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "swish":
        return F.silu
    else:
        raise NotImplementedError


class GLU(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
        self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        g = self.gate(x)
        x = self.fc1(x)
        x = self.activation_fn(x.float()).type_as(x) * g
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x

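GLU is a gated feed-forward block (SwiGLU-style when activation_fn is "swish"): fc1 and gate both project the input to ffn_dim, the fc1 branch goes through the activation and is multiplied elementwise by the gate projection, and fc2 maps back to embed_dim, preserving the input shape. A small shape check with illustrative sizes:

glu = GLU(embed_dim=512, ffn_dim=1024, activation_fn="swish",
          dropout=0.1, activation_dropout=0.0)
x = torch.randn(2, 16, 512)  # (batch, seq_len, embed_dim)
y = glu(x)
assert y.shape == x.shape
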
@@ -0,0 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output
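RMSNorm rescales each vector by the root mean square of its elements, output = x / sqrt(mean(x^2) + eps) * weight; unlike LayerNorm it subtracts no mean and adds no bias. A quick numerical check against that formula (illustrative values, weight left at its default of ones):

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(4, 8)
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), expected, atol=1e-5)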