From a4d830b87d68da6c59db78900c77223371a280f2 Mon Sep 17 00:00:00 2001
From: johan bjorck
Date: Mon, 24 Apr 2023 17:29:39 +0000
Subject: [PATCH] make num experts optional arg

---
 torchscale/component/xmoe/global_groups.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchscale/component/xmoe/global_groups.py b/torchscale/component/xmoe/global_groups.py
index 1846913..3ee5752 100644
--- a/torchscale/component/xmoe/global_groups.py
+++ b/torchscale/component/xmoe/global_groups.py
@@ -8,7 +8,7 @@ def _find_my_group_index(grouped_ranks):
             return i
     raise RuntimeError
 
-def get_moe_group(moe_expert_count):
+def get_moe_group(moe_expert_count=None):
     if dist.is_initialized():
         if not hasattr(get_moe_group, "_moe_groups"):
             world_size = dist.get_world_size()
@@ -25,6 +25,7 @@ def get_moe_group(moe_expert_count):
                     for i in range(moe_expert_count)
                 ]
 
+            get_moe_group._moe_expert_count = moe_expert_count
             get_moe_group._moe_group_idx = moe_groups
             get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
 
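
Notes:

The patch works because the first call to get_moe_group already caches the
process groups on the function object; this change additionally caches the
expert count, so later callers can omit the argument. Below is a minimal
sketch of how global_groups.py might read after the patch. The diff only
shows the changed hunks, so the group-construction branch and the return
statement here are reconstructed from context and should be treated as
assumptions rather than verbatim torchscale code.

import torch.distributed as dist


def _find_my_group_index(grouped_ranks):
    # Locate the index of the group containing this rank.
    my_rank = dist.get_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError


def get_moe_group(moe_expert_count=None):
    if dist.is_initialized():
        if not hasattr(get_moe_group, "_moe_groups"):
            # First call: moe_expert_count must be provided so the process
            # groups can be built. (Branch logic reconstructed, not in diff.)
            world_size = dist.get_world_size()
            if world_size <= moe_expert_count:
                assert moe_expert_count % world_size == 0
                moe_groups = [[i] for i in range(world_size)]
            else:
                assert world_size % moe_expert_count == 0
                ranks_per_group = world_size // moe_expert_count
                moe_groups = [
                    [i + j * moe_expert_count for j in range(ranks_per_group)]
                    for i in range(moe_expert_count)
                ]

            # New in this patch: cache the expert count alongside the groups,
            # which is what lets the argument default to None afterwards.
            get_moe_group._moe_expert_count = moe_expert_count
            get_moe_group._moe_group_idx = moe_groups
            get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]

        # Subsequent calls (where moe_expert_count may be None) only read
        # the cache. (Return shape is an assumption, not shown in the diff.)
        my_group_idx = _find_my_group_index(get_moe_group._moe_group_idx)
        return get_moe_group._moe_expert_count, get_moe_group._moe_groups[my_group_idx]

Under that assumption, the first call on each rank passes the count
explicitly, e.g. get_moe_group(moe_expert_count=args.moe_expert_count)
(args is a hypothetical config object), and any later call can simply be
get_moe_group() and still recover the cached count and process group.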