make num experts optional arg
This commit is contained in:
parent
886c8ab408
commit
a4d830b87d
|
@ -8,7 +8,7 @@ def _find_my_group_index(grouped_ranks):
|
||||||
return i
|
return i
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
def get_moe_group(moe_expert_count):
|
def get_moe_group(moe_expert_count=None):
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
if not hasattr(get_moe_group, "_moe_groups"):
|
if not hasattr(get_moe_group, "_moe_groups"):
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
@ -25,6 +25,7 @@ def get_moe_group(moe_expert_count):
|
||||||
for i in range(moe_expert_count)
|
for i in range(moe_expert_count)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
get_moe_group._moe_expert_count = moe_expert_count
|
||||||
get_moe_group._moe_group_idx = moe_groups
|
get_moe_group._moe_group_idx = moe_groups
|
||||||
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user