From 886c8ab408bdd350e992cad5151a18eca637ab7f Mon Sep 17 00:00:00 2001
From: johan bjorck
Date: Sat, 22 Apr 2023 00:32:05 +0000
Subject: [PATCH] make process groups global

---
 torchscale/component/feedforward_network.py |  7 ++-
 torchscale/component/xmoe/global_groups.py  | 64 +++++++++++++++++++++
 torchscale/component/xmoe/moe_layer.py      | 62 +-------------------
 3 files changed, 73 insertions(+), 60 deletions(-)
 create mode 100644 torchscale/component/xmoe/global_groups.py

diff --git a/torchscale/component/feedforward_network.py b/torchscale/component/feedforward_network.py
index abea43b..cc187a8 100644
--- a/torchscale/component/feedforward_network.py
+++ b/torchscale/component/feedforward_network.py
@@ -10,6 +10,9 @@
 except ModuleNotFoundError:
     from torch.nn import LayerNorm
 
+from .xmoe.global_groups import get_moe_group
+
+
 class set_torch_seed(object):
     def __init__(self, seed):
         assert isinstance(seed, int)
@@ -70,7 +73,9 @@ def make_experts(args, embed_dim, expert_ffn_dim):
             world_size % args.moe_expert_count == 0
         ), f"{world_size}, {args.moe_expert_count}"
 
-        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
+        moe_idx, _ = get_moe_group(args.moe_expert_count)
+
+        with set_torch_seed(start_seed + moe_idx):
             expert_list.append(
                 FeedForwardNetwork(
                     embed_dim,
diff --git a/torchscale/component/xmoe/global_groups.py b/torchscale/component/xmoe/global_groups.py
new file mode 100644
index 0000000..1846913
--- /dev/null
+++ b/torchscale/component/xmoe/global_groups.py
@@ -0,0 +1,64 @@
+import torch.distributed as dist
+
+
+def _find_my_group_index(grouped_ranks):
+    my_rank = dist.get_rank()
+    for i, group in enumerate(grouped_ranks):
+        if my_rank in group:
+            return i
+    raise RuntimeError
+
+def get_moe_group(moe_expert_count):
+    if dist.is_initialized():
+        if not hasattr(get_moe_group, "_moe_groups"):
+            world_size = dist.get_world_size()
+
+            if world_size <= moe_expert_count:
+                assert moe_expert_count % world_size == 0
+                moe_groups = [[i] for i in range(world_size)]
+
+            else:
+                assert world_size % moe_expert_count == 0
+                ranks_per_group = world_size // moe_expert_count
+                moe_groups = [
+                    [i + j * moe_expert_count for j in range(ranks_per_group)]
+                    for i in range(moe_expert_count)
+                ]
+
+            get_moe_group._moe_group_idx = moe_groups
+            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
+
+        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
+        return my_group_idx, get_moe_group._moe_groups[my_group_idx]
+
+
+def get_all2all_group(moe_expert_count):
+    if dist.is_initialized():
+        if not hasattr(get_all2all_group, "_all2all_groups"):
+            world_size = dist.get_world_size()
+
+            # more experts than world size
+            if world_size <= moe_expert_count:
+                assert moe_expert_count % world_size == 0
+                all2all_groups = [[i for i in range(world_size)]]
+
+            # larger world than num experts
+            else:
+                assert world_size % moe_expert_count == 0
+                ranks_per_group = world_size // moe_expert_count
+                all2all_groups = [
+                    [i * moe_expert_count + j for j in range(moe_expert_count)]
+                    for i in range(ranks_per_group)
+                ]
+
+            get_all2all_group._all2all_group_idx = all2all_groups
+            get_all2all_group._all2all_groups = [
+                dist.new_group(g) for g in all2all_groups
+            ]
+
+        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
+        return get_all2all_group._all2all_groups[my_group_idx]
+
+
+
+
diff --git a/torchscale/component/xmoe/moe_layer.py b/torchscale/component/xmoe/moe_layer.py
index fe5d691..51e7713 100644
--- a/torchscale/component/xmoe/moe_layer.py
+++ b/torchscale/component/xmoe/moe_layer.py
@@ -18,6 +18,8 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
+from .global_groups import get_all2all_group, get_moe_group
+
 try:
     from fairseq.modules.moe import MOELayer
 
@@ -61,64 +63,6 @@ class _AllToAll(torch.autograd.Function):
         return (None, _AllToAll.apply(ctx.group, *grad_output))
 
 
-def _find_my_group_index(grouped_ranks):
-    my_rank = dist.get_rank()
-    for i, group in enumerate(grouped_ranks):
-        if my_rank in group:
-            return i
-    raise RuntimeError
-
-
-def get_moe_group(moe_expert_count):
-    if dist.is_initialized():
-        if not hasattr(get_moe_group, "_moe_groups"):
-            world_size = dist.get_world_size()
-
-            if world_size <= moe_expert_count:
-                assert moe_expert_count % world_size == 0
-                moe_groups = [[i] for i in range(world_size)]
-
-            else:
-                assert world_size % moe_expert_count == 0
-                ranks_per_group = world_size // moe_expert_count
-                moe_groups = [
-                    [i + j * moe_expert_count for j in range(ranks_per_group)]
-                    for i in range(moe_expert_count)
-                ]
-
-            get_moe_group._moe_group_idx = moe_groups
-            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
-
-        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
-        return get_moe_group._moe_groups[my_group_idx]
-
-
-def get_all2all_group(moe_expert_count):
-    if dist.is_initialized():
-        if not hasattr(get_all2all_group, "_all2all_groups"):
-            world_size = dist.get_world_size()
-
-            # more experts than world size
-            if world_size <= moe_expert_count:
-                assert moe_expert_count % world_size == 0
-                all2all_groups = [[i for i in range(world_size)]]
-
-            # larger world than num experts
-            else:
-                assert world_size % moe_expert_count == 0
-                ranks_per_group = world_size // moe_expert_count
-                all2all_groups = [
-                    [i * moe_expert_count + j for j in range(moe_expert_count)]
-                    for i in range(ranks_per_group)
-                ]
-
-            get_all2all_group._all2all_group_idx = all2all_groups
-            get_all2all_group._all2all_groups = [
-                dist.new_group(g) for g in all2all_groups
-            ]
-
-        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
-        return get_all2all_group._all2all_groups[my_group_idx]
 
 
 class MOELayer(Base):
@@ -149,7 +93,7 @@ class MOELayer(Base):
             self.experts = cast(ModuleList, experts)
         else:
             self.experts = ModuleList([experts])
-        self.expert_group = get_moe_group(args.moe_expert_count)
+        _, self.expert_group = get_moe_group(args.moe_expert_count)
         self.all2all_group = get_all2all_group(args.moe_expert_count)
         self.world_size = dist.get_world_size(group=self.expert_group)
         self.all2all_size = dist.get_world_size(group=self.all2all_group)
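
The two helpers now live in one module, build their process groups once per
process (the dist.new_group handles are memoized on the function objects), and
get_moe_group returns (group_index, group) instead of just the group: hence
the `_, self.expert_group = get_moe_group(...)` unpack in MOELayer and the
`start_seed + moe_idx` seeding in make_experts. Below is a minimal
single-process sketch of the grouping arithmetic with the torch.distributed
calls left out; the helper names and the world_size=8, moe_expert_count=4
example are illustrative assumptions, not code from this patch:

    # Mirrors the rank layouts built in global_groups.py, minus the
    # torch.distributed calls, so it runs in a single process.

    def moe_group_ranks(world_size, moe_expert_count):
        # One group per expert: with more ranks than experts, the ranks
        # holding copies of expert i sit moe_expert_count apart.
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            return [[i] for i in range(world_size)]
        assert world_size % moe_expert_count == 0
        ranks_per_group = world_size // moe_expert_count
        return [
            [i + j * moe_expert_count for j in range(ranks_per_group)]
            for i in range(moe_expert_count)
        ]

    def all2all_group_ranks(world_size, moe_expert_count):
        # A single group of all ranks when experts >= ranks; otherwise
        # contiguous blocks of moe_expert_count ranks exchange tokens.
        if world_size <= moe_expert_count:
            assert moe_expert_count % world_size == 0
            return [list(range(world_size))]
        assert world_size % moe_expert_count == 0
        ranks_per_group = world_size // moe_expert_count
        return [
            [i * moe_expert_count + j for j in range(moe_expert_count)]
            for i in range(ranks_per_group)
        ]

    print(moe_group_ranks(8, 4))      # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(all2all_group_ranks(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]

With this layout, rank r lands in moe group r % moe_expert_count whenever
world_size > moe_expert_count, so seeding experts with start_seed + moe_idx
reproduces the previous start_seed + ddp_rank % args.moe_expert_count while
keeping all of the group bookkeeping in global_groups.py.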