# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

from .xmoe.global_groups import get_moe_group


class set_torch_seed(object):
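    """Context manager that seeds the torch (and, if available, CUDA) RNGs on
    construction and restores the previously captured RNG state on exit, so
    per-expert initialization is reproducible without perturbing global
    randomness."""
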
    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)


def make_experts(args, embed_dim, expert_ffn_dim):
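    """Build the expert FFNs hosted by this rank for a mixture-of-experts layer.

    When there are at least as many experts as ranks, each rank instantiates
    moe_expert_count // world_size experts, each under its own deterministic
    seed offset. Otherwise several ranks share one expert, and the seed offset
    is derived from the MoE group index returned by get_moe_group, so ranks in
    the same expert group use the same offset.
    """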
    world_size = (
        1
        if not torch.distributed.is_initialized()
        else torch.distributed.get_world_size()
    )
    expert_list = []
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()
    # at least as many experts as GPUs
    if args.moe_expert_count >= world_size:
        assert (
            args.moe_expert_count % world_size == 0
        ), f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    FeedForwardNetwork(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                        args.layernorm_eps,
                        args.subln,
                    )
                )
    else:
        assert (
            world_size % args.moe_expert_count == 0
        ), f"{world_size}, {args.moe_expert_count}"

        moe_idx, _ = get_moe_group(args.moe_expert_count)

        with set_torch_seed(start_seed + moe_idx):
            expert_list.append(
                FeedForwardNetwork(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                    args.layernorm_eps,
                    args.subln,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts


def get_activation_fn(activation):
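    """Map an activation name ("relu", "gelu", or "swish") to its functional
    form; "swish" resolves to F.silu, which is the same function under a
    different name."""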
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    elif activation == "swish":
        return F.silu
    else:
        raise NotImplementedError


class FeedForwardNetwork(nn.Module):
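    """Position-wise feed-forward block: fc1 -> activation -> activation
    dropout -> optional sub-LayerNorm (when subln=True) -> fc2 -> dropout,
    applied token-wise and preserving the input shape."""
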
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(-1, x.size(-1))
        x = self.fc1(x)
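        # compute the activation in fp32 and cast back to the input dtype,
        # keeping the nonlinearity numerically stable under fp16/bf16 training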
        x = self.activation_fn(x.float()).type_as(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = x.view(x_shape)
        x = self.dropout_module(x)
        return x
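

# Minimal usage sketch for FeedForwardNetwork (illustrative only; the
# hyperparameters below are assumptions for demonstration, not values taken
# from this repo's configs). Because of the relative import above, this block
# only runs when the module is executed within its package context.
if __name__ == "__main__":
    ffn = FeedForwardNetwork(
        embed_dim=512,
        ffn_dim=2048,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    inp = torch.randn(2, 16, 512)  # (batch, sequence, embed_dim)
    out = ffn(inp)
    assert out.shape == inp.shape  # the FFN preserves the input shape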