make num experts optional arg

This commit is contained in:
johan bjorck 2023-04-24 17:29:39 +00:00
parent 886c8ab408
commit a4d830b87d

View File

@ -8,7 +8,7 @@ def _find_my_group_index(grouped_ranks):
return i return i
raise RuntimeError raise RuntimeError
def get_moe_group(moe_expert_count): def get_moe_group(moe_expert_count=None):
if dist.is_initialized(): if dist.is_initialized():
if not hasattr(get_moe_group, "_moe_groups"): if not hasattr(get_moe_group, "_moe_groups"):
world_size = dist.get_world_size() world_size = dist.get_world_size()
@ -25,6 +25,7 @@ def get_moe_group(moe_expert_count):
for i in range(moe_expert_count) for i in range(moe_expert_count)
] ]
get_moe_group._moe_expert_count = moe_expert_count
get_moe_group._moe_group_idx = moe_groups get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups] get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]