Merge pull request #30 from njb-ms/main

make pgs global
This commit is contained in:
Shuming Ma 2023-04-26 12:12:55 +08:00 committed by GitHub
commit b59b6f87b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 74 additions and 60 deletions

View File

@ -10,6 +10,9 @@ except ModuleNotFoundError:
from torch.nn import LayerNorm
from .xmoe.global_groups import get_moe_group
class set_torch_seed(object):
def __init__(self, seed):
assert isinstance(seed, int)
@ -70,7 +73,9 @@ def make_experts(args, embed_dim, expert_ffn_dim):
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
moe_idx, _ = get_moe_group(args.moe_expert_count)
with set_torch_seed(start_seed + moe_idx):
expert_list.append(
FeedForwardNetwork(
embed_dim,

View File

@ -0,0 +1,65 @@
import torch.distributed as dist
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count=None):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_expert_count = moe_expert_count
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return my_group_idx, get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]

View File

@ -18,6 +18,8 @@ import torch.distributed as dist
from torch import Tensor
from torch.nn import Module, ModuleList
from .global_groups import get_all2all_group, get_moe_group
try:
from fairseq.modules.moe import MOELayer
@ -61,64 +63,6 @@ class _AllToAll(torch.autograd.Function):
return (None, _AllToAll.apply(ctx.group, *grad_output))
def _find_my_group_index(grouped_ranks):
my_rank = dist.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError
def get_moe_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size()
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
return get_moe_group._moe_groups[my_group_idx]
def get_all2all_group(moe_expert_count):
if dist.is_initialized():
if not hasattr(get_all2all_group, "_all2all_groups"):
world_size = dist.get_world_size()
# more experts than world size
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
all2all_groups = [[i for i in range(world_size)]]
# larger world than num experts
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
class MOELayer(Base):
@ -149,7 +93,7 @@ class MOELayer(Base):
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.expert_group = get_moe_group(args.moe_expert_count)
_, self.expert_group = get_moe_group(args.moe_expert_count)
self.all2all_group = get_all2all_group(args.moe_expert_count)
self.world_size = dist.get_world_size(group=self.expert_group)
self.all2all_size = dist.get_world_size(group=self.all2all_group)