make pgs global

parent 4ae3b248ee
commit 886c8ab408
@@ -10,6 +10,9 @@ except ModuleNotFoundError:
     from torch.nn import LayerNorm
 
+from .xmoe.global_groups import get_moe_group
+
+
 class set_torch_seed(object):
     def __init__(self, seed):
         assert isinstance(seed, int)
@@ -70,7 +73,9 @@ def make_experts(args, embed_dim, expert_ffn_dim):
             world_size % args.moe_expert_count == 0
         ), f"{world_size}, {args.moe_expert_count}"
 
-        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
+        moe_idx, _ = get_moe_group(args.moe_expert_count)
+
+        with set_torch_seed(start_seed + moe_idx):
             expert_list.append(
                 FeedForwardNetwork(
                     embed_dim,
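The practical effect of this hunk is that the expert-initialization seed is now offset by the rank's MoE group index (from the new get_moe_group helper) rather than by ddp_rank % args.moe_expert_count, so every rank holding a replica of the same expert enters set_torch_seed with the same value. A minimal sketch of how that plays out, assuming 8 ranks, 4 experts, and the group layout get_moe_group would build for them (all numbers are illustrative, not taken from the diff):

    # Sketch only: assumed 8-rank / 4-expert layout; start_seed is really drawn
    # via torch.randint inside make_experts.
    start_seed = 1234
    moe_groups = [[0, 4], [1, 5], [2, 6], [3, 7]]  # group i holds the replicas of expert i

    for moe_idx, ranks in enumerate(moe_groups):
        for rank in ranks:
            # every rank in group moe_idx uses the same seed, so replicated
            # experts start from identical weights
            print(f"rank {rank}: set_torch_seed({start_seed + moe_idx})")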
torchscale/component/xmoe/global_groups.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import torch.distributed as dist


def _find_my_group_index(grouped_ranks):
    my_rank = dist.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError


def get_moe_group(moe_expert_count):
    if dist.is_initialized():
        if not hasattr(get_moe_group, "_moe_groups"):
            world_size = dist.get_world_size()

            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                moe_groups = [[i] for i in range(world_size)]

            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                moe_groups = [
                    [i + j * moe_expert_count for j in range(ranks_per_group)]
                    for i in range(moe_expert_count)
                ]

            get_moe_group._moe_group_idx = moe_groups
            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
        return my_group_idx, get_moe_group._moe_groups[my_group_idx]


def get_all2all_group(moe_expert_count):
    if dist.is_initialized():
        if not hasattr(get_all2all_group, "_all2all_groups"):
            world_size = dist.get_world_size()

            # more experts than world size
            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                all2all_groups = [[i for i in range(world_size)]]

            # larger world than num experts
            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                all2all_groups = [
                    [i * moe_expert_count + j for j in range(moe_expert_count)]
                    for i in range(ranks_per_group)
                ]

            get_all2all_group._all2all_group_idx = all2all_groups
            get_all2all_group._all2all_groups = [
                dist.new_group(g) for g in all2all_groups
            ]

        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
        return get_all2all_group._all2all_groups[my_group_idx]
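The helpers above build the process groups lazily and cache them on the function objects themselves (get_moe_group._moe_groups, get_all2all_group._all2all_groups), so dist.new_group runs only once per process and every later caller receives the same, now global, groups. For a concrete picture of the layouts they produce, here is a small self-contained sketch of the same list comprehensions, assuming an illustrative world_size of 8 and moe_expert_count of 4 (no torch.distributed required):

    # Pure-Python sketch of the rank layouts built above; world_size and
    # moe_expert_count are assumed example values.
    world_size = 8
    moe_expert_count = 4
    ranks_per_group = world_size // moe_expert_count

    # get_moe_group: group i collects the ranks that host replicas of expert i
    moe_groups = [
        [i + j * moe_expert_count for j in range(ranks_per_group)]
        for i in range(moe_expert_count)
    ]
    # -> [[0, 4], [1, 5], [2, 6], [3, 7]]

    # get_all2all_group: each group holds one rank per expert and carries the
    # all-to-all token exchange
    all2all_groups = [
        [i * moe_expert_count + j for j in range(moe_expert_count)]
        for i in range(ranks_per_group)
    ]
    # -> [[0, 1, 2, 3], [4, 5, 6, 7]]

    print(moe_groups)
    print(all2all_groups)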
@@ -18,6 +18,8 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.nn import Module, ModuleList
 
+from .global_groups import get_all2all_group, get_moe_group
+
 try:
     from fairseq.modules.moe import MOELayer
@@ -61,64 +63,6 @@ class _AllToAll(torch.autograd.Function):
         return (None, _AllToAll.apply(ctx.group, *grad_output))
 
 
-def _find_my_group_index(grouped_ranks):
-    my_rank = dist.get_rank()
-    for i, group in enumerate(grouped_ranks):
-        if my_rank in group:
-            return i
-    raise RuntimeError
-
-
-def get_moe_group(moe_expert_count):
-    if dist.is_initialized():
-        if not hasattr(get_moe_group, "_moe_groups"):
-            world_size = dist.get_world_size()
-
-            if world_size <= moe_expert_count:
-                assert moe_expert_count % world_size == 0
-                moe_groups = [[i] for i in range(world_size)]
-
-            else:
-                assert world_size % moe_expert_count == 0
-                ranks_per_group = world_size // moe_expert_count
-                moe_groups = [
-                    [i + j * moe_expert_count for j in range(ranks_per_group)]
-                    for i in range(moe_expert_count)
-                ]
-
-            get_moe_group._moe_group_idx = moe_groups
-            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
-
-        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
-        return get_moe_group._moe_groups[my_group_idx]
-
-
-def get_all2all_group(moe_expert_count):
-    if dist.is_initialized():
-        if not hasattr(get_all2all_group, "_all2all_groups"):
-            world_size = dist.get_world_size()
-
-            # more experts than world size
-            if world_size <= moe_expert_count:
-                assert moe_expert_count % world_size == 0
-                all2all_groups = [[i for i in range(world_size)]]
-
-            # larger world than num experts
-            else:
-                assert world_size % moe_expert_count == 0
-                ranks_per_group = world_size // moe_expert_count
-                all2all_groups = [
-                    [i * moe_expert_count + j for j in range(moe_expert_count)]
-                    for i in range(ranks_per_group)
-                ]
-
-            get_all2all_group._all2all_group_idx = all2all_groups
-            get_all2all_group._all2all_groups = [
-                dist.new_group(g) for g in all2all_groups
-            ]
-
-        my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
-        return get_all2all_group._all2all_groups[my_group_idx]
-
-
 class MOELayer(Base):
@@ -149,7 +93,7 @@ class MOELayer(Base):
             self.experts = cast(ModuleList, experts)
         else:
             self.experts = ModuleList([experts])
-        self.expert_group = get_moe_group(args.moe_expert_count)
+        _, self.expert_group = get_moe_group(args.moe_expert_count)
         self.all2all_group = get_all2all_group(args.moe_expert_count)
        self.world_size = dist.get_world_size(group=self.expert_group)
         self.all2all_size = dist.get_world_size(group=self.all2all_group)
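Because get_moe_group now returns an (index, process_group) pair, the constructor unpacks it and discards the index; the group it keeps is the same as before. A rough sketch of the resulting group sizes, reusing the illustrative 8-rank / 4-expert layout from above (not values stated in the diff):

    # Illustrative call site; args and the sizes in the comments assume the
    # 8-rank / 4-expert example layout sketched earlier.
    _, expert_group = get_moe_group(args.moe_expert_count)    # group index discarded
    all2all_group = get_all2all_group(args.moe_expert_count)

    # dist.get_world_size(group=expert_group)   -> 2  (ranks replicating one expert)
    # dist.get_world_size(group=all2all_group)  -> 4  (one rank per expert)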