diff --git a/torchscale/component/xmoe/global_groups.py b/torchscale/component/xmoe/global_groups.py index 1846913..3ee5752 100644 --- a/torchscale/component/xmoe/global_groups.py +++ b/torchscale/component/xmoe/global_groups.py @@ -8,7 +8,7 @@ def _find_my_group_index(grouped_ranks): return i raise RuntimeError -def get_moe_group(moe_expert_count): +def get_moe_group(moe_expert_count=None): if dist.is_initialized(): if not hasattr(get_moe_group, "_moe_groups"): world_size = dist.get_world_size() @@ -25,6 +25,7 @@ def get_moe_group(moe_expert_count): for i in range(moe_expert_count) ] + get_moe_group._moe_expert_count = moe_expert_count get_moe_group._moe_group_idx = moe_groups get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]