flake8 lint checks
parent 4714557e89
commit 994e4665a2
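
The changes below are routine flake8 cleanups: unused imports dropped (F401), over-long lines and help strings wrapped (E501), and two blank lines enforced before top-level definitions (E302). The entry-point scripts are instead opted out wholesale with a file-level suppression. A minimal sketch of that mechanism (module names are illustrative):

# flake8: noqa
# Placed on a line of its own near the top of a file, this comment tells
# flake8 to skip every check in the file; a trailing "# noqa" would
# silence only its own line.
import models  # imported for their registration side effects; without the
import tasks   # suppression, flake8 would flag these as unused (F401)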
@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,10 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
+# flake8: noqa
 import models
 import tasks
 
 from fairseq_cli.generate import cli_main
 
 if __name__ == "__main__":
     cli_main()

@@ -1,10 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
+# flake8: noqa
 import models
 import tasks
 
 from fairseq_cli.interactive import cli_main
 
 if __name__ == "__main__":
     cli_main()

@@ -33,4 +33,4 @@ for file in os.listdir(models_dir):
             )
             group_args = parser.add_argument_group("Additional command-line arguments")
             MODEL_REGISTRY[model_name].add_args(group_args)
             globals()[model_name + "_parser"] = parser

@@ -1,24 +1,21 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import math
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Optional
 from dataclasses import dataclass, field
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils
-from fairseq.distributed import fsdp_wrap
-from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models.transformer import (
     DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
 )
 from fairseq.modules import PositionalEmbedding
 from fairseq.models.squad import SQuADHead
-from torch import Tensor
 from omegaconf import II
 from .machine_translation import MTEncoder as Encoder
 from torchscale.architecture.config import EncoderConfig

@@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
 
 logger = logging.getLogger(__name__)
 
+
 @dataclass
 class BertConfig(FairseqDataclass):
     activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(

@@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
     moe_eval_capacity_token_fraction: Optional[float] = field(
         default=0.25,
         metadata={
-            "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+            "help": (
+                "Default: 0.25, Fraction of tokens as capacity during validation, "
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
         }
     )
     moe_normalize_expert_grad: Optional[str] = field(

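The rewrapped help message relies on Python's implicit concatenation of adjacent string literals, so the text the user sees is byte-for-byte identical to the old single long line. A quick illustration:

wrapped = (
    "Default: 0.25, Fraction of tokens as capacity during validation, "
    "if set to negative, use same as training. range: (0.0, 1.0]."
)
single = "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
assert wrapped == single  # adjacent literals are joined at compile time
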
@@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
         default=False, metadata={"help": "records all to all perf stats during distributed training"}
     )
     dummy_a2a: Optional[bool] = field(
-        default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+        default=False, metadata={
+            "help": "By passes all to all during distributed training by returning the input buffer as output"}
     )
     moe_batch_prioritized_routing: Optional[bool] = field(
         default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}

@@ -202,7 +204,7 @@ class BertConfig(FairseqDataclass):
     subln: Optional[bool] = field(
         default=False,
     )
 
 
 @register_model("mlm", dataclass=BertConfig)
 class BertModel(BaseFairseqModel):

@@ -245,9 +247,9 @@ class BertModel(BaseFairseqModel):
         config.override(args)
 
         encoder = Encoder(
             config,
             embed_tokens=embed_tokens,
             embed_positions=embed_positions,
             output_projection=lm_head,
             is_encoder_decoder=False,
             dictionary=task.dictionary,

@@ -259,14 +261,14 @@ class BertModel(BaseFairseqModel):
     def build_embedding(cls, args, dictionary, embed_dim, path=None):
         embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
         return embed_tokens
 
     @classmethod
     def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
         return LMHead(embed_dim, output_dim, activation_fn, weight)
 
     def output_layer(self, features, masked_tokens=None):
         return self.encoder.output_projection(features, masked_tokens=masked_tokens)
 
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:

@@ -286,12 +288,12 @@ class BertModel(BaseFairseqModel):
             self.args.pooler_activation_fn,
             self.args.pooler_dropout,
         )
 
     def register_question_answering_head(self, name, num_classes=None):
         self.classification_heads[name] = SQuADHead(
             self.args.encoder_embed_dim,
         )
 
     def upgrade_state_dict_named(self, state_dict, name):
         prefix = name + '.' if name != '' else ''
 

@@ -342,15 +344,16 @@ class BertModel(BaseFairseqModel):
             if prefix + 'classification_heads.' + k not in state_dict:
                 logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
                 state_dict[prefix + 'classification_heads.' + k] = v
 
     def forward(
         self,
         src_tokens=None,
         features_only=False,
         return_all_hiddens=False,
         classification_head_name=None,
         masked_tokens=None,
-        **kwargs):
+        **kwargs
+    ):
         encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
         x, extra = encoder_out["encoder_out"], encoder_out
         x = x.transpose(0, 1)

@@ -362,7 +365,7 @@ class BertModel(BaseFairseqModel):
 
         return x, extra
 
 
 class ClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 

@@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
         x = self.out_proj(x)
         return x
 
+
 class LMHead(nn.Module):
     """Head for masked language modeling."""
 

@@ -459,4 +463,4 @@ def base_unilm_architecture(args):
     args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
     args.offload_activations = getattr(args, "offload_activations", False)
     if args.offload_activations:
         args.checkpoint_activations = True

@@ -6,12 +6,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
+import logging
 from dataclasses import dataclass, field
 from typing import Optional
 import torch
 
-from fairseq import options, utils
+from fairseq import utils
 from fairseq import distributed_utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import (

@@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
 from omegaconf import II
 
 DEFAULT_MAX_TARGET_POSITIONS = 1024
-import logging
 logger = logging.getLogger(__name__)
 
+
 @dataclass
 class LanguageConfig(FairseqDataclass):
     activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(

@@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
     moe_eval_capacity_token_fraction: Optional[float] = field(
         default=0.25,
         metadata={
-            "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+            "help": (
+                "Default: 0.25, Fraction of tokens as capacity during validation, "
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
         }
     )
     moe_normalize_expert_grad: Optional[str] = field(

@@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
         default=False, metadata={"help": "records all to all perf stats during distributed training"}
     )
     dummy_a2a: Optional[bool] = field(
-        default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+        default=False, metadata={
+            "help": "By passes all to all during distributed training by returning the input buffer as output"}
     )
     moe_batch_prioritized_routing: Optional[bool] = field(
         default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}

@@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
             output_projection.weight = embed_tokens.weight
         else:
             output_projection = torch.nn.Linear(
-                decoder_embed_dim, len(task.dictionary), bias=False
+                args.decoder_embed_dim, len(task.dictionary), bias=False
             )
             torch.nn.init.normal_(
-                output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
+                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
             )
 
         if (

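This hunk fixes behavior as well as style: there is no local decoder_embed_dim inside build_model, so the untied-projection branch would presumably have raised NameError the first time it ran (flake8's F821, undefined name); reading the dimension off args makes the branch executable. A standalone sketch of the corrected pattern, with the dimension and vocabulary size assumed purely for illustration:

import torch

class Args:  # stand-in for the parsed fairseq argument namespace
    decoder_embed_dim = 512

args, vocab_size = Args(), 64000
output_projection = torch.nn.Linear(args.decoder_embed_dim, vocab_size, bias=False)
# initialization scaled by 1/sqrt(embed_dim), as in the diff above
torch.nn.init.normal_(output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5)
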
@@ -252,22 +256,23 @@ class LanguageModel(FairseqLanguageModel):
                 and getattr(args, 'ddp_backend', None) != "fully_sharded"
             )
         ):
-            assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
+            assert args.fp16_no_flatten_grads, \
+                "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
 
         args.ddp_rank = distributed_utils.get_data_parallel_rank()
 
         config = DecoderConfig()
         config.override(args)
 
         decoder = LMDecoder(
             config,
             embed_tokens,
             embed_positions,
             output_projection,
             is_encoder_decoder=False,
             dictionary=task.dictionary,
         )
 
         return cls(args, decoder)
 
     @classmethod

@@ -283,7 +288,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
 
     def max_positions(self):
         return self.embed_positions.max_positions
 
     def reorder_incremental_state_scripting(
         self,
         incremental_state,

@@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
                 result = incremental_state[module][key].index_select(0, new_order)
                 incremental_state[module][key] = result
 
+
 @register_model_architecture("lm", "lm_base")
 def base_lm_architecture(args):
     # backward compatibility for older model checkpoints

@@ -357,4 +363,3 @@ def base_lm_architecture(args):
     args.offload_activations = getattr(args, "offload_activations", False)
     if args.offload_activations:
         args.checkpoint_activations = True
-

@@ -6,33 +6,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import functools
-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
-import torch.nn as nn
 from fairseq import utils
-from fairseq.distributed import utils as dist_utils, fsdp_wrap
+from fairseq.distributed import utils as fsdp_wrap
 from fairseq import distributed_utils
-from fairseq import checkpoint_utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqEncoderDecoderModel,
-    FairseqIncrementalDecoder,
     register_model,
     register_model_architecture,
 )
 from fairseq.models.transformer import Embedding
-from fairseq.modules import (
-    AdaptiveSoftmax,
-    FairseqDropout,
-    LayerDropModuleList,
-    LayerNorm,
-    PositionalEmbedding,
-    SinusoidalPositionalEmbedding,
-)
-from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules import PositionalEmbedding
 from torchscale.architecture.encoder import Encoder
 from torchscale.architecture.config import EncoderConfig, DecoderConfig
 from .language_modeling import LMDecoder as MTDecoder

@@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
                             help="Use FP32 computations in MoE top2 gating function")
         parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                             help="policy for second expert, options: all/sampling/random")
-        parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
-                            help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
+        parser.add_argument(
+            '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
+            help=(
+                "whether to normalize gate probs before or after dropping experts "
+                "for capacity and randomization"
+            )
+        )
         parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                             help="MoE Expert FFN dimension")
         parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                             help="Use top1 gate instead of top2")
-        parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
-                            help="Fraction of tokens as capacity during validation" + \
-                                 "if set to negative, use same as training. range: (0.0, 1.0].")
+        parser.add_argument(
+            '--moe-eval-capacity-token-fraction', type=float, default=0.25,
+            help=(
+                "Fraction of tokens as capacity during validation"
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
+        )
         parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                             help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
 
         parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                             help="Don't route padding tokens to any expert")
         parser.add_argument('--use-xmoe', default=False, action='store_true',

@@ -207,7 +202,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
         if getattr(args, "max_target_positions", None) is None:
             args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
 
         args.ddp_rank = distributed_utils.get_data_parallel_rank()
 
         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

@@ -279,18 +274,18 @@ class TranslationModel(FairseqEncoderDecoderModel):
 
         encoder = cls.build_encoder(
             args,
             encoder_embed_tokens,
             encoder_embed_positions,
             src_dict,
         )
         decoder = cls.build_decoder(
             args,
             decoder_embed_tokens,
             decoder_embed_positions,
             output_projection,
             tgt_dict,
         )
 
         if not args.share_all_embeddings:
             min_params_to_wrap = getattr(
                 args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP

@@ -317,9 +312,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
         config.override(args)
 
         return MTEncoder(
             config,
             embed_tokens,
             embed_positions,
             is_encoder_decoder=True,
             dictionary=dictionary,
         )

@@ -330,8 +325,8 @@ class TranslationModel(FairseqEncoderDecoderModel):
         config.override(args)
 
         return MTDecoder(
             config,
             embed_tokens,
             embed_positions,
             output_projection,
             is_encoder_decoder=True,

@@ -348,7 +343,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
         **kwargs
     ):
         encoder_out = self.encoder(
             src_tokens,
             return_all_hiddens=return_all_hiddens
         )
         decoder_out = self.decoder(

@@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
     def max_positions(self):
         return self.embed_positions.max_positions
 
+
 @register_model_architecture("mt", "mt_base")
 def base_architecture(args):
     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)

@@ -32,4 +32,4 @@ for file in os.listdir(tasks_dir):
             # fmt: on
             group_args = parser.add_argument_group("Additional command-line arguments")
             TASK_REGISTRY[task_name].add_args(group_args)
             globals()[task_name + "_parser"] = parser

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,14 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import math
-import re
-import sys
-import time
 import torch
 from infinibatch.iterators import CheckpointableIterator
 from . import utils
 
+
 class BaseBatchGen(CheckpointableIterator):
     """
     This is a base class for batch generators that use infinibatch

@@ -26,7 +23,7 @@ class BaseBatchGen(CheckpointableIterator):
         Build infinibatch iterator and assign to self._iter
         """
         raise NotImplementedError()
 
     def _move_to_tensor(self, batch):
 
         def to_tensor(x):

@@ -47,16 +44,16 @@ class BaseBatchGen(CheckpointableIterator):
 
     def __next__(self):
         return next(self._iter)
 
     def setstate(self, value):
         self._iter.setstate(value)
 
     def getstate(self):
         return self._iter.getstate()
 
     def close(self):
         self._iter.close()
 
     def __len__(self) -> int:
         return 819200000
 

|
||||||
|
|
||||||
@property
|
@property
|
||||||
def first_batch(self):
|
def first_batch(self):
|
||||||
return "DUMMY"
|
return "DUMMY"
|
||||||
|
|
|
@@ -1,13 +1,8 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import glob
 import os
-import torch
 import numpy as np
-import time
-import json
-import random
 import itertools
 import copy
 

@@ -55,15 +50,15 @@ class MLMLoader(BaseBatchGen):
         self.batch_read_ahead = args.batch_read_ahead
 
         self._build_iter()
 
     def _build_iter(self):
         tokenized_lines = self._multilingual_tokenize()
         self.padded_batches = self._batchify(tokenized_lines)
 
         prefetch_batches = iterators.PrefetchIterator(
             self.padded_batches,
             buffer_size=10000,
             buffer_in_main_process=True,
             log_empty_buffer_warning=True and self.shard_id == 0,
         )
 

@@ -85,14 +80,14 @@ class MLMLoader(BaseBatchGen):
                 weights.append(float(data['weight']))
             else:
                 weights.append(int(data['count']))
 
         if len(multilingual_iters) == 1:
             return multilingual_iters[0]
 
         sampling_iterator = WeightIterator(weights)
         control_iterator = NativeCheckpointableIterator(sampling_iterator)
         tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
 
         return tokenized_lines
 
     def _tokenize(self, data):

@@ -109,7 +104,7 @@ class MLMLoader(BaseBatchGen):
         dataset = list(
             zip(
                 data['source'],
                 itertools.repeat(data['source_lang']),
             )
         )
 

@@ -117,27 +112,26 @@ class MLMLoader(BaseBatchGen):
             chunk_files = \
                 iterators.InfinitePermutationSourceIterator(
                     dataset,
                     seed=self.seed,
                     shuffle=self.shuffle,
                     num_instances=self.num_shards,
                     instance_rank=self.shard_id,
                 )
         else:
             chunk_files = \
                 iterators.ChunkedSourceIterator(
                     dataset,
                     num_instances=self.num_shards,
                     instance_rank=self.shard_id,
                 )
 
         tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
         tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
 
         return tokenized_lines
 
-
 
     def _batchify(self, lines):
 
         if self.max_sentences is not None:
             if self.batch_read_ahead > 0:
                 lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)

@@ -145,14 +139,15 @@ class MLMLoader(BaseBatchGen):
         else:
             def dynamic_batch_size(sample):
                 lengths = [len(x) for x in sample]
-                batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
+                batch_size = self.max_tokens // max(lengths)
+                batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
                 return max(1, batch_size)
 
             batches = iterators.BucketedReadaheadBatchIterator(
                 lines,
                 read_ahead=self.batch_read_ahead,
                 key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
                 batch_size=dynamic_batch_size,
                 shuffle=self.shuffle,
                 seed=self.seed,
             )

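The split assignment also makes the batching arithmetic easier to check: the token budget is divided by the longest sequence in the sample, then rounded down to a multiple of required_batch_size_multiple. A worked example with assumed values:

max_tokens, longest, multiple = 4096, 90, 8      # illustrative values only
batch_size = max_tokens // longest               # 45 sequences fit the token budget
batch_size = batch_size // multiple * multiple   # floor 45 to a multiple of 8 -> 40
print(max(1, batch_size))                        # 40; max(1, ...) guards tiny budgets
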
@@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen):
         s2s_target_max_length = max([len(x[3]) for x in batch])
 
         mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
                                  fill_value=self.dictionary.pad())
         mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                  fill_value=self.dictionary.pad())
         s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
                                  fill_value=self.dictionary.pad())
         s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                  fill_value=self.dictionary.pad())
         s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                      fill_value=self.dictionary.pad())
 
         for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
             mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids

@@ -182,7 +177,7 @@ class MLMLoader(BaseBatchGen):
             s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
             s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
             s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
 
         ret_batch = {
             'net_input': {
                 'src_tokens': mlm_source_ids.astype(np.int64),

@@ -199,16 +194,16 @@ class MLMLoader(BaseBatchGen):
         )
 
         return padded_batches
 
     def _prepare(self, _random, doc):
         nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
         nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
         return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
 
     def _mask_lm(self, _random, doc):
         def mask_tokens():
-            return f"<mask>"
+            return "<mask>"
 
         length = len(doc)
         mask_tokens_num = int(length * self.args.mask_prob)
         mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)

@@ -222,11 +217,11 @@ class MLMLoader(BaseBatchGen):
             # masked_tokens.append(nonmasked_tokens[position])
             masked_tokens[position] = nonmasked_tokens[position]
             nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
 
         return nonmasked_tokens, masked_tokens
 
     def _span_corruption(self, _random, doc):
 
         def mask_tokens(i):
             return f"<mask_{i}>"
 

@@ -244,7 +239,7 @@ class MLMLoader(BaseBatchGen):
         _random.shuffle(possible_split_positions)
         noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
         noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
 
         possible_insert_positions = list(range(nonnoise_tokens_num))
         _random.shuffle(possible_insert_positions)
         noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

@@ -260,14 +255,14 @@ class MLMLoader(BaseBatchGen):
                 noise_spans.append(doc[start_pos:end_pos])
             else:
                 noise_spans.append([mask_id] + doc[start_pos:end_pos])
 
             if getattr(self.args, "remove_source_sentinel", False):
                 nonnoise_spans.extend(doc[last_end:start_pos])
             else:
                 nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
 
             last_end = end_pos
 
         nonnoise_spans.extend(doc[last_end:])
         noise_spans = sum(noise_spans, [])
 

|
||||||
def _read_from_files(self, source_file, source_lang):
|
def _read_from_files(self, source_file, source_lang):
|
||||||
# data = []
|
# data = []
|
||||||
file_path = os.path.join(self.data_dir, source_file)
|
file_path = os.path.join(self.data_dir, source_file)
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
print('| file {} not exists'.format(file_path), flush=True)
|
print('| file {} not exists'.format(file_path), flush=True)
|
||||||
return iter([]) # skip bad file
|
return iter([]) # skip bad file
|
||||||
|
|
||||||
with open(file_path, 'r', encoding='utf8') as f:
|
with open(file_path, 'r', encoding='utf8') as f:
|
||||||
lines = f.read().strip().split('\n')
|
lines = f.read().strip().split('\n')
|
||||||
|
@ -292,7 +287,7 @@ class MLMLoader(BaseBatchGen):
|
||||||
yield doc
|
yield doc
|
||||||
doc = [self.dictionary.bos()]
|
doc = [self.dictionary.bos()]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tokenized_line = self.tokenizer.EncodeAsPieces(line)
|
tokenized_line = self.tokenizer.EncodeAsPieces(line)
|
||||||
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
|
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
|
||||||
|
|
||||||
|
@@ -308,4 +303,4 @@ class MLMLoader(BaseBatchGen):
                 # data.append(doc)
                 yield doc
 
         # return data

@@ -1,14 +1,13 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
-import os
-import gzip
 import numpy as np
 from random import Random
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Optional
 import collections
 from infinibatch import iterators
 
+
 def apply_to_sample(f, sample):
     if hasattr(sample, "__len__") and len(sample) == 0:
         return {}

@@ -34,6 +33,7 @@ def apply_to_sample(f, sample):
 
     return _apply(sample)
 
+
 class NativeCheckpointableIterator(iterators.CheckpointableIterator):
     def __init__(self, iterable: Iterable):
         self._input_iterable = iterable

@@ -44,13 +44,16 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
 
     def setstate(self, checkpoint: Optional[Dict]):
         self._iterator = iter(self._input_iterable)
-        self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
+        self._num_items_yielded = iterators._advance_iterator(
+            self._iterator,
+            checkpoint['num_items_yielded']
+        ) if checkpoint is not None else 0
 
     def __next__(self):
         item = next(self._iterator)
         self._num_items_yielded += 1
         return item
 
     def close(self):
         pass
 

@@ -61,17 +64,17 @@ class WeightIterator(object):
         self.seed = seed
         self.control_index = list(range(len(weights)))
         self.setstate(None)
 
     def __iter__(self):
         return self
 
     def getstate(self):
         return {"random_state": self._random_state}
 
     def setstate(self, checkpoint):
         self._random_state = checkpoint["random_state"] if checkpoint else None
         self._random = None  # this will trigger the lazy initialization in self.__next__
 
     def __next__(self):
         if self._random is None:
             self._random = Random(self.seed)

@@ -80,6 +83,6 @@ class WeightIterator(object):
             idx = self._random.choices(self.control_index, self.weights)[0]
             self._random_state = self._random.getstate()
             return idx
 
     def close(self):
         pass

@@ -10,16 +10,13 @@ import logging
 import os
 from argparse import Namespace
 import json
-from omegaconf import MISSING, II, OmegaConf
-from typing import Any
+from omegaconf import MISSING, II
 
-import numpy as np
 from fairseq import utils
 from fairseq.data import Dictionary
 from fairseq.tasks import FairseqTask, register_task
 from .data.mlm_loader import MLMLoader
 from fairseq.dataclass import FairseqDataclass, ChoiceEnum
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
 import sentencepiece as spm
 
 logger = logging.getLogger(__name__)

@@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
 SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
 SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
 
+
 @dataclass
 class PretrainingConfig(FairseqDataclass):
     data: str = field(

@@ -163,11 +161,11 @@ class PLMTask(FairseqTask):
             'shuffle': True if split == 'train' else False,
         }
         self.datasets[split] = Namespace(**self.datasets[split])
 
     def dataset(self, split):
         if split not in self.datasets:
             raise KeyError("Dataset not loaded: " + split)
 
         return self.datasets[split]
 
     def get_batch_iterator(

@@ -207,4 +205,4 @@ class PLMTask(FairseqTask):
 
     @property
     def target_dictionary(self):
         return self.dictionary

@@ -1,11 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]
 
+# flake8: noqa
 import models
 import tasks
 
 from fairseq_cli.train import cli_main
 
-
 if __name__ == "__main__":
     cli_main()

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
 import torch.distributed as dist
 import math
 
+
 @torch.no_grad()
 def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
     def grad_exists(p):

@@ -75,4 +76,4 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
         clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
         for g in grads + expert_grads + sharded_grads + base_expert_grads:
             g.mul_(clip_coef)
     return total_norm

setup.py
@@ -25,4 +25,4 @@ setup(
     classifiers=[
         'Programming Language :: Python :: 3',
     ],
 )

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -23,6 +23,7 @@ testcases = [
     {"fsdp": True}
 ]
 
+
 @pytest.mark.parametrize("args", testcases)
 def test_decoder(args):
     config = DecoderConfig(**args)

@@ -23,6 +23,7 @@ testcases = [
     {"fsdp": True}
 ]
 
+
 @pytest.mark.parametrize("args", testcases)
 def test_encoder(args):
     config = EncoderConfig(**args)

@@ -25,13 +25,14 @@ testcases = [
     {"fsdp": True}
 ]
 
+
 @pytest.mark.parametrize("args", testcases)
 def test_decoder(args):
     config = EncoderDecoderConfig(**args)
     model = EncoderDecoder(
         config,
         encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
         decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
         encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
         decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
     )

@@ -41,6 +42,6 @@ def test_decoder(args):
 
     model(
         src_tokens=src_tokens,
         prev_output_tokens=prev_output_tokens,
         features_only=True,
     )

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

@@ -355,7 +355,8 @@ def top2gating(
     if has_tutel:
         locations1_s = torch.sum(locations1 * mask1_, dim=1)
         locations2_s = torch.sum(locations2 * mask2_, dim=1)
-        return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
+        return l_aux, metadata, capacity, num_experts, \
+            [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
 
     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]