flake8 lint checks

This commit is contained in:
shumingma 2022-11-26 08:10:15 -08:00
parent 4714557e89
commit 994e4665a2
28 changed files with 168 additions and 163 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models import models
import tasks import tasks

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models import models
import tasks import tasks

View File

@@ -1,24 +1,21 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import math
import logging import logging
from typing import Any, Dict, List, Optional from typing import Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.distributed import fsdp_wrap from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import ( from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
) )
from fairseq.modules import PositionalEmbedding from fairseq.modules import PositionalEmbedding
from fairseq.models.squad import SQuADHead from fairseq.models.squad import SQuADHead
from torch import Tensor
from omegaconf import II from omegaconf import II
from .machine_translation import MTEncoder as Encoder from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig from torchscale.architecture.config import EncoderConfig
@@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass @dataclass
class BertConfig(FairseqDataclass): class BertConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field( moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25, default=0.25,
metadata={ metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]." "help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
} }
) )
moe_normalize_expert_grad: Optional[str] = field( moe_normalize_expert_grad: Optional[str] = field(
@@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"} default=False, metadata={"help": "records all to all perf stats during distributed training"}
) )
dummy_a2a: Optional[bool] = field( dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"} default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
) )
moe_batch_prioritized_routing: Optional[bool] = field( moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."} default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -350,7 +352,8 @@ class BertModel(BaseFairseqModel):
return_all_hiddens=False, return_all_hiddens=False,
classification_head_name=None, classification_head_name=None,
masked_tokens=None, masked_tokens=None,
**kwargs): **kwargs
):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens) encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
x, extra = encoder_out["encoder_out"], encoder_out x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1) x = x.transpose(0, 1)
@@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
x = self.out_proj(x) x = self.out_proj(x)
return x return x
class LMHead(nn.Module): class LMHead(nn.Module):
"""Head for masked language modeling.""" """Head for masked language modeling."""

View File

@@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional
import torch import torch
from fairseq import options, utils from fairseq import utils
from fairseq import distributed_utils from fairseq import distributed_utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import ( from fairseq.models import (
@@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
from omegaconf import II from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass @dataclass
class LanguageConfig(FairseqDataclass): class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field( moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25, default=0.25,
metadata={ metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]." "help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
} }
) )
moe_normalize_expert_grad: Optional[str] = field( moe_normalize_expert_grad: Optional[str] = field(
@@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"} default=False, metadata={"help": "records all to all perf stats during distributed training"}
) )
dummy_a2a: Optional[bool] = field( dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"} default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
) )
moe_batch_prioritized_routing: Optional[bool] = field( moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."} default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
output_projection.weight = embed_tokens.weight output_projection.weight = embed_tokens.weight
else: else:
output_projection = torch.nn.Linear( output_projection = torch.nn.Linear(
decoder_embed_dim, len(task.dictionary), bias=False args.decoder_embed_dim, len(task.dictionary), bias=False
) )
torch.nn.init.normal_( torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5 output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
) )
if ( if (
@@ -252,7 +256,8 @@ class LanguageModel(FairseqLanguageModel):
and getattr(args, 'ddp_backend', None) != "fully_sharded" and getattr(args, 'ddp_backend', None) != "fully_sharded"
) )
): ):
assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm" assert args.fp16_no_flatten_grads, \
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank() args.ddp_rank = distributed_utils.get_data_parallel_rank()
@@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
result = incremental_state[module][key].index_select(0, new_order) result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result incremental_state[module][key] = result
@register_model_architecture("lm", "lm_base") @register_model_architecture("lm", "lm_base")
def base_lm_architecture(args): def base_lm_architecture(args):
# backward compatibility for older model checkpoints # backward compatibility for older model checkpoints
@@ -357,4 +363,3 @@ def base_lm_architecture(args):
args.offload_activations = getattr(args, "offload_activations", False) args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations: if args.offload_activations:
args.checkpoint_activations = True args.checkpoint_activations = True

View File

@@ -6,33 +6,20 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import functools from typing import Dict, List, Optional, Tuple
import math
from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn
from fairseq import utils from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap from fairseq.distributed import utils as fsdp_wrap
from fairseq import distributed_utils from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import ( from fairseq.models import (
FairseqEncoder, FairseqEncoder,
FairseqEncoderDecoderModel, FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model, register_model,
register_model_architecture, register_model_architecture,
) )
from fairseq.models.transformer import Embedding from fairseq.models.transformer import Embedding
from fairseq.modules import ( from fairseq.modules import PositionalEmbedding
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder from .language_modeling import LMDecoder as MTDecoder
@@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
help="Use FP32 computations in MoE top2 gating function") help="Use FP32 computations in MoE top2 gating function")
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling', parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
help="policy for second expert, options: all/sampling/random") help="policy for second expert, options: all/sampling/random")
parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true', parser.add_argument(
help="whether to normalize gate probs before or after dropping experts for capacity and randomization") '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
help=(
"whether to normalize gate probs before or after dropping experts "
"for capacity and randomization"
)
)
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0, parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
help="MoE Expert FFN dimension") help="MoE Expert FFN dimension")
parser.add_argument('--moe-top1-expert', default=False, action='store_true', parser.add_argument('--moe-top1-expert', default=False, action='store_true',
help="Use top1 gate instead of top2") help="Use top1 gate instead of top2")
parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25, parser.add_argument(
help="Fraction of tokens as capacity during validation" + \ '--moe-eval-capacity-token-fraction', type=float, default=0.25,
"if set to negative, use same as training. range: (0.0, 1.0].") help=(
"Fraction of tokens as capacity during validation"
"if set to negative, use same as training. range: (0.0, 1.0]."
)
)
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size', parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'") help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true', parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
help="Don't route padding tokens to any expert") help="Don't route padding tokens to any expert")
parser.add_argument('--use-xmoe', default=False, action='store_true', parser.add_argument('--use-xmoe', default=False, action='store_true',
@@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
def max_positions(self): def max_positions(self):
return self.embed_positions.max_positions return self.embed_positions.max_positions
@register_model_architecture("mt", "mt_base") @register_model_architecture("mt", "mt_base")
def base_architecture(args): def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None) args.encoder_embed_path = getattr(args, "encoder_embed_path", None)

View File

@@ -1,14 +1,11 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import math
import re
import sys
import time
import torch import torch
from infinibatch.iterators import CheckpointableIterator from infinibatch.iterators import CheckpointableIterator
from . import utils from . import utils
class BaseBatchGen(CheckpointableIterator): class BaseBatchGen(CheckpointableIterator):
""" """
This is a base class for batch generators that use infinibatch This is a base class for batch generators that use infinibatch

View File

@@ -1,13 +1,8 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import glob
import os import os
import torch
import numpy as np import numpy as np
import time
import json
import random
import itertools import itertools
import copy import copy
@@ -135,7 +130,6 @@ class MLMLoader(BaseBatchGen):
return tokenized_lines return tokenized_lines
def _batchify(self, lines): def _batchify(self, lines):
if self.max_sentences is not None: if self.max_sentences is not None:
@@ -145,7 +139,8 @@ class MLMLoader(BaseBatchGen):
else: else:
def dynamic_batch_size(sample): def dynamic_batch_size(sample):
lengths = [len(x) for x in sample] lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple batch_size = self.max_tokens // max(lengths)
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
return max(1, batch_size) return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator( batches = iterators.BucketedReadaheadBatchIterator(
@@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen):
s2s_target_max_length = max([len(x[3]) for x in batch]) s2s_target_max_length = max([len(x[3]) for x in batch])
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32, mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad()) fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32, mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad()) fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32, s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad()) fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32, s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad()) fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32, s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad()) fill_value=self.dictionary.pad())
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch): for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
@@ -207,7 +202,7 @@ class MLMLoader(BaseBatchGen):
def _mask_lm(self, _random, doc): def _mask_lm(self, _random, doc):
def mask_tokens(): def mask_tokens():
return f"<mask>" return "<mask>"
length = len(doc) length = len(doc)
mask_tokens_num = int(length * self.args.mask_prob) mask_tokens_num = int(length * self.args.mask_prob)
@@ -279,7 +274,7 @@ class MLMLoader(BaseBatchGen):
if not os.path.exists(file_path): if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True) print('| file {} not exists'.format(file_path), flush=True)
return iter([]) # skip bad file return iter([]) # skip bad file
with open(file_path, 'r', encoding='utf8') as f: with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n') lines = f.read().strip().split('\n')

View File

@@ -1,14 +1,13 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import os
import gzip
import numpy as np import numpy as np
from random import Random from random import Random
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union from typing import Dict, Iterable, Optional
import collections import collections
from infinibatch import iterators from infinibatch import iterators
def apply_to_sample(f, sample): def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0: if hasattr(sample, "__len__") and len(sample) == 0:
return {} return {}
@@ -34,6 +33,7 @@ def apply_to_sample(f, sample):
return _apply(sample) return _apply(sample)
class NativeCheckpointableIterator(iterators.CheckpointableIterator): class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def __init__(self, iterable: Iterable): def __init__(self, iterable: Iterable):
self._input_iterable = iterable self._input_iterable = iterable
@@ -44,7 +44,10 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def setstate(self, checkpoint: Optional[Dict]): def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable) self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0 self._num_items_yielded = iterators._advance_iterator(
self._iterator,
checkpoint['num_items_yielded']
) if checkpoint is not None else 0
def __next__(self): def __next__(self):
item = next(self._iterator) item = next(self._iterator)

View File

@@ -10,16 +10,13 @@ import logging
import os import os
from argparse import Namespace from argparse import Namespace
import json import json
from omegaconf import MISSING, II, OmegaConf from omegaconf import MISSING, II
from typing import Any
import numpy as np
from fairseq import utils from fairseq import utils
from fairseq.data import Dictionary from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task from fairseq.tasks import FairseqTask, register_task
from .data.mlm_loader import MLMLoader from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
import sentencepiece as spm import sentencepiece as spm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass @dataclass
class PretrainingConfig(FairseqDataclass): class PretrainingConfig(FairseqDataclass):
data: str = field( data: str = field(

View File

@@ -1,11 +1,11 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models import models
import tasks import tasks
from fairseq_cli.train import cli_main from fairseq_cli.train import cli_main
if __name__ == "__main__": if __name__ == "__main__":
cli_main() cli_main()

View File

@@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist import torch.distributed as dist
import math import math
@torch.no_grad() @torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor: def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def grad_exists(p): def grad_exists(p):

View File

@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True} {"fsdp": True}
] ]
@pytest.mark.parametrize("args", testcases) @pytest.mark.parametrize("args", testcases)
def test_decoder(args): def test_decoder(args):
config = DecoderConfig(**args) config = DecoderConfig(**args)

View File

@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True} {"fsdp": True}
] ]
@pytest.mark.parametrize("args", testcases) @pytest.mark.parametrize("args", testcases)
def test_encoder(args): def test_encoder(args):
config = EncoderConfig(**args) config = EncoderConfig(**args)

View File

@@ -25,6 +25,7 @@ testcases = [
{"fsdp": True} {"fsdp": True}
] ]
@pytest.mark.parametrize("args", testcases) @pytest.mark.parametrize("args", testcases)
def test_decoder(args): def test_decoder(args):
config = EncoderDecoderConfig(**args) config = EncoderDecoderConfig(**args)

View File

@@ -355,7 +355,8 @@ def top2gating(
if has_tutel: if has_tutel:
locations1_s = torch.sum(locations1 * mask1_, dim=1) locations1_s = torch.sum(locations1 * mask1_, dim=1)
locations2_s = torch.sum(locations2 * mask2_, dim=1) locations2_s = torch.sum(locations2 * mask2_, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s] return l_aux, metadata, capacity, num_experts, \
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
# Store the capacity location for each token # Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1) locations1_s = torch.sum(locations1 * mask1, dim=1)