flake8 lint checks
parent 4714557e89
commit 994e4665a2

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,10 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

+# flake8: noqa
 import models
 import tasks

 from fairseq_cli.generate import cli_main

 if __name__ == "__main__":
-    cli_main()
+    cli_main()

@@ -1,10 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

+# flake8: noqa
 import models
 import tasks

 from fairseq_cli.interactive import cli_main

 if __name__ == "__main__":
-    cli_main()
+    cli_main()

@@ -33,4 +33,4 @@ for file in os.listdir(models_dir):
             )
             group_args = parser.add_argument_group("Additional command-line arguments")
             MODEL_REGISTRY[model_name].add_args(group_args)
-            globals()[model_name + "_parser"] = parser
+            globals()[model_name + "_parser"] = parser

@@ -1,24 +1,21 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

-import math
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Optional
 from dataclasses import dataclass, field

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import utils
-from fairseq.distributed import fsdp_wrap
-from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models.transformer import (
     DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
 )
 from fairseq.modules import PositionalEmbedding
 from fairseq.models.squad import SQuADHead
-from torch import Tensor
 from omegaconf import II
 from .machine_translation import MTEncoder as Encoder
 from torchscale.architecture.config import EncoderConfig

@@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024

 logger = logging.getLogger(__name__)

+
 @dataclass
 class BertConfig(FairseqDataclass):
     activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(

@@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
     moe_eval_capacity_token_fraction: Optional[float] = field(
         default=0.25,
         metadata={
-            "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+            "help": (
+                "Default: 0.25, Fraction of tokens as capacity during validation, "
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
         }
     )
     moe_normalize_expert_grad: Optional[str] = field(

@@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
         default=False, metadata={"help": "records all to all perf stats during distributed training"}
     )
     dummy_a2a: Optional[bool] = field(
-        default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+        default=False, metadata={
+            "help": "By passes all to all during distributed training by returning the input buffer as output"}
     )
     moe_batch_prioritized_routing: Optional[bool] = field(
         default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}

@@ -202,7 +204,7 @@ class BertConfig(FairseqDataclass):
     subln: Optional[bool] = field(
         default=False,
     )

-
+
 @register_model("mlm", dataclass=BertConfig)
 class BertModel(BaseFairseqModel):

@@ -245,9 +247,9 @@ class BertModel(BaseFairseqModel):
         config.override(args)

         encoder = Encoder(
-            config,
-            embed_tokens=embed_tokens,
-            embed_positions=embed_positions,
+            config,
+            embed_tokens=embed_tokens,
+            embed_positions=embed_positions,
             output_projection=lm_head,
             is_encoder_decoder=False,
             dictionary=task.dictionary,

@@ -259,14 +261,14 @@ class BertModel(BaseFairseqModel):
     def build_embedding(cls, args, dictionary, embed_dim, path=None):
         embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
         return embed_tokens
-
+
     @classmethod
     def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
         return LMHead(embed_dim, output_dim, activation_fn, weight)
-
+
     def output_layer(self, features, masked_tokens=None):
         return self.encoder.output_projection(features, masked_tokens=masked_tokens)
-
+
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:

@@ -286,12 +288,12 @@ class BertModel(BaseFairseqModel):
             self.args.pooler_activation_fn,
             self.args.pooler_dropout,
         )
-
+
     def register_question_answering_head(self, name, num_classes=None):
         self.classification_heads[name] = SQuADHead(
             self.args.encoder_embed_dim,
         )
-
+
     def upgrade_state_dict_named(self, state_dict, name):
         prefix = name + '.' if name != '' else ''

@@ -342,15 +344,16 @@ class BertModel(BaseFairseqModel):
             if prefix + 'classification_heads.' + k not in state_dict:
                 logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
                 state_dict[prefix + 'classification_heads.' + k] = v
-
+
     def forward(
-        self,
-        src_tokens=None,
+        self,
+        src_tokens=None,
         features_only=False,
         return_all_hiddens=False,
-        classification_head_name=None,
+        classification_head_name=None,
         masked_tokens=None,
-        **kwargs):
+        **kwargs
+    ):
         encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
         x, extra = encoder_out["encoder_out"], encoder_out
         x = x.transpose(0, 1)

@@ -362,7 +365,7 @@ class BertModel(BaseFairseqModel):

         return x, extra

-
+
 class ClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""

@@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
         x = self.out_proj(x)
         return x

+
 class LMHead(nn.Module):
     """Head for masked language modeling."""

@@ -459,4 +463,4 @@ def base_unilm_architecture(args):
     args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
     args.offload_activations = getattr(args, "offload_activations", False)
     if args.offload_activations:
-        args.checkpoint_activations = True
+        args.checkpoint_activations = True

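The long `help` strings above are wrapped by relying on Python's implicit concatenation of adjacent string literals, which keeps the lines inside flake8's length limit without changing the message. A minimal standalone sketch of that idiom (the variable name and assertion are illustrative, not from the diff):

# Adjacent string literals inside parentheses are joined at compile time,
# so one long help message can be split across several short source lines.
help_text = (
    "Default: 0.25, Fraction of tokens as capacity during validation, "
    "if set to negative, use same as training. range: (0.0, 1.0]."
)

assert "validation, if set" in help_text  # the pieces form a single string
print(help_text)
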
@@ -6,12 +6,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import math
+import logging
 from dataclasses import dataclass, field
 from typing import Optional
 import torch

-from fairseq import options, utils
+from fairseq import utils
 from fairseq import distributed_utils
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
 from fairseq.models import (

@@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
 from omegaconf import II

 DEFAULT_MAX_TARGET_POSITIONS = 1024
-import logging
 logger = logging.getLogger(__name__)

+
 @dataclass
 class LanguageConfig(FairseqDataclass):
     activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(

@@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
     moe_eval_capacity_token_fraction: Optional[float] = field(
         default=0.25,
         metadata={
-            "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+            "help": (
+                "Default: 0.25, Fraction of tokens as capacity during validation, "
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
         }
     )
     moe_normalize_expert_grad: Optional[str] = field(

@@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
         default=False, metadata={"help": "records all to all perf stats during distributed training"}
     )
     dummy_a2a: Optional[bool] = field(
-        default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+        default=False, metadata={
+            "help": "By passes all to all during distributed training by returning the input buffer as output"}
     )
     moe_batch_prioritized_routing: Optional[bool] = field(
         default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}

@@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
             output_projection.weight = embed_tokens.weight
         else:
             output_projection = torch.nn.Linear(
-                decoder_embed_dim, len(task.dictionary), bias=False
+                args.decoder_embed_dim, len(task.dictionary), bias=False
             )
             torch.nn.init.normal_(
-                output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
+                output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
             )

         if (

@@ -252,22 +256,23 @@ class LanguageModel(FairseqLanguageModel):
                 and getattr(args, 'ddp_backend', None) != "fully_sharded"
             )
         ):
-            assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
-
+            assert args.fp16_no_flatten_grads, \
+                "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
+
         args.ddp_rank = distributed_utils.get_data_parallel_rank()

         config = DecoderConfig()
         config.override(args)

         decoder = LMDecoder(
-            config,
-            embed_tokens,
+            config,
+            embed_tokens,
             embed_positions,
             output_projection,
             is_encoder_decoder=False,
             dictionary=task.dictionary,
         )
-
+
         return cls(args, decoder)

     @classmethod

@@ -283,7 +288,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):

     def max_positions(self):
         return self.embed_positions.max_positions
-
+
     def reorder_incremental_state_scripting(
         self,
         incremental_state,

@@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
                 result = incremental_state[module][key].index_select(0, new_order)
                 incremental_state[module][key] = result

+
 @register_model_architecture("lm", "lm_base")
 def base_lm_architecture(args):
     # backward compatibility for older model checkpoints

@@ -357,4 +363,3 @@ def base_lm_architecture(args):
     args.offload_activations = getattr(args, "offload_activations", False)
     if args.offload_activations:
         args.checkpoint_activations = True
-

@@ -6,33 +6,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-import functools
-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
-import torch.nn as nn
 from fairseq import utils
-from fairseq.distributed import utils as dist_utils, fsdp_wrap
+from fairseq.distributed import utils as fsdp_wrap
 from fairseq import distributed_utils
-from fairseq import checkpoint_utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqEncoderDecoderModel,
-    FairseqIncrementalDecoder,
     register_model,
     register_model_architecture,
 )
 from fairseq.models.transformer import Embedding
-from fairseq.modules import (
-    AdaptiveSoftmax,
-    FairseqDropout,
-    LayerDropModuleList,
-    LayerNorm,
-    PositionalEmbedding,
-    SinusoidalPositionalEmbedding,
-)
-from fairseq.modules.checkpoint_activations import checkpoint_wrapper
+from fairseq.modules import PositionalEmbedding
 from torchscale.architecture.encoder import Encoder
 from torchscale.architecture.config import EncoderConfig, DecoderConfig
 from .language_modeling import LMDecoder as MTDecoder

@@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
                             help="Use FP32 computations in MoE top2 gating function")
         parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
                             help="policy for second expert, options: all/sampling/random")
-        parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
-                            help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
+        parser.add_argument(
+            '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
+            help=(
+                "whether to normalize gate probs before or after dropping experts "
+                "for capacity and randomization"
+            )
+        )
         parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
                             help="MoE Expert FFN dimension")
         parser.add_argument('--moe-top1-expert', default=False, action='store_true',
                             help="Use top1 gate instead of top2")
-        parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
-                            help="Fraction of tokens as capacity during validation" + \
-                                 "if set to negative, use same as training. range: (0.0, 1.0].")
+        parser.add_argument(
+            '--moe-eval-capacity-token-fraction', type=float, default=0.25,
+            help=(
+                "Fraction of tokens as capacity during validation"
+                "if set to negative, use same as training. range: (0.0, 1.0]."
+            )
+        )
         parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
                             help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
-
         parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
                             help="Don't route padding tokens to any expert")
         parser.add_argument('--use-xmoe', default=False, action='store_true',

@@ -207,7 +202,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
             args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
         if getattr(args, "max_target_positions", None) is None:
             args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
-
+
         args.ddp_rank = distributed_utils.get_data_parallel_rank()

         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

@@ -279,18 +274,18 @@ class TranslationModel(FairseqEncoderDecoderModel):

         encoder = cls.build_encoder(
             args,
-            encoder_embed_tokens,
+            encoder_embed_tokens,
             encoder_embed_positions,
-            src_dict,
+            src_dict,
         )
         decoder = cls.build_decoder(
-            args,
+            args,
             decoder_embed_tokens,
             decoder_embed_positions,
             output_projection,
             tgt_dict,
         )
-
+
         if not args.share_all_embeddings:
             min_params_to_wrap = getattr(
                 args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP

@@ -317,9 +312,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
         config.override(args)

         return MTEncoder(
-            config,
-            embed_tokens,
-            embed_positions,
+            config,
+            embed_tokens,
+            embed_positions,
             is_encoder_decoder=True,
             dictionary=dictionary,
         )

@@ -330,8 +325,8 @@ class TranslationModel(FairseqEncoderDecoderModel):
         config.override(args)

         return MTDecoder(
-            config,
-            embed_tokens,
+            config,
+            embed_tokens,
             embed_positions,
             output_projection,
             is_encoder_decoder=True,

@@ -348,7 +343,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
         **kwargs
     ):
         encoder_out = self.encoder(
-            src_tokens,
+            src_tokens,
             return_all_hiddens=return_all_hiddens
         )
         decoder_out = self.decoder(

@@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
     def max_positions(self):
         return self.embed_positions.max_positions

+
 @register_model_architecture("mt", "mt_base")
 def base_architecture(args):
     args.encoder_embed_path = getattr(args, "encoder_embed_path", None)

@@ -32,4 +32,4 @@ for file in os.listdir(tasks_dir):
             # fmt: on
             group_args = parser.add_argument_group("Additional command-line arguments")
             TASK_REGISTRY[task_name].add_args(group_args)
-            globals()[task_name + "_parser"] = parser
+            globals()[task_name + "_parser"] = parser

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,14 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

-import math
-import re
-import sys
-import time
 import torch
 from infinibatch.iterators import CheckpointableIterator
 from . import utils

+
 class BaseBatchGen(CheckpointableIterator):
     """
     This is a base class for batch generators that use infinibatch

@@ -26,7 +23,7 @@ class BaseBatchGen(CheckpointableIterator):
         Build infinibatch iterator and assign to self._iter
         """
         raise NotImplementedError()
-
+
     def _move_to_tensor(self, batch):

         def to_tensor(x):

@@ -47,16 +44,16 @@ class BaseBatchGen(CheckpointableIterator):

     def __next__(self):
         return next(self._iter)
-
+
     def setstate(self, value):
         self._iter.setstate(value)
-
+
     def getstate(self):
         return self._iter.getstate()
-
+
     def close(self):
         self._iter.close()
-
+
     def __len__(self) -> int:
         return 819200000

@@ -78,4 +75,4 @@ class BaseBatchGen(CheckpointableIterator):

     @property
     def first_batch(self):
-        return "DUMMY"
+        return "DUMMY"

@@ -1,13 +1,8 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

-import glob
 import os
-import torch
 import numpy as np
-import time
 import json
-import random
 import itertools
 import copy
-

@@ -55,15 +50,15 @@ class MLMLoader(BaseBatchGen):
         self.batch_read_ahead = args.batch_read_ahead

         self._build_iter()
-
+
     def _build_iter(self):
         tokenized_lines = self._multilingual_tokenize()
         self.padded_batches = self._batchify(tokenized_lines)
-
+
         prefetch_batches = iterators.PrefetchIterator(
-            self.padded_batches,
-            buffer_size=10000,
-            buffer_in_main_process=True,
+            self.padded_batches,
+            buffer_size=10000,
+            buffer_in_main_process=True,
             log_empty_buffer_warning=True and self.shard_id == 0,
         )

@@ -85,14 +80,14 @@ class MLMLoader(BaseBatchGen):
                 weights.append(float(data['weight']))
             else:
                 weights.append(int(data['count']))
-
+
         if len(multilingual_iters) == 1:
             return multilingual_iters[0]

         sampling_iterator = WeightIterator(weights)
         control_iterator = NativeCheckpointableIterator(sampling_iterator)
         tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
-
+
         return tokenized_lines

     def _tokenize(self, data):

@@ -109,7 +104,7 @@ class MLMLoader(BaseBatchGen):
         dataset = list(
             zip(
                 data['source'],
-                itertools.repeat(data['source_lang']),
+                itertools.repeat(data['source_lang']),
             )
         )

@@ -117,27 +112,26 @@ class MLMLoader(BaseBatchGen):
             chunk_files = \
                 iterators.InfinitePermutationSourceIterator(
                     dataset,
-                    seed=self.seed,
-                    shuffle=self.shuffle,
-                    num_instances=self.num_shards,
+                    seed=self.seed,
+                    shuffle=self.shuffle,
+                    num_instances=self.num_shards,
                     instance_rank=self.shard_id,
                 )
         else:
             chunk_files = \
                 iterators.ChunkedSourceIterator(
                     dataset,
-                    num_instances=self.num_shards,
+                    num_instances=self.num_shards,
                     instance_rank=self.shard_id,
                 )
-
+
         tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
         tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
-
+
         return tokenized_lines
-
+
     def _batchify(self, lines):

-
         if self.max_sentences is not None:
             if self.batch_read_ahead > 0:
                 lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)

@@ -145,14 +139,15 @@ class MLMLoader(BaseBatchGen):
         else:
             def dynamic_batch_size(sample):
                 lengths = [len(x) for x in sample]
-                batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
+                batch_size = self.max_tokens // max(lengths)
+                batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
                 return max(1, batch_size)
-
+
             batches = iterators.BucketedReadaheadBatchIterator(
                 lines,
-                read_ahead=self.batch_read_ahead,
-                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
-                batch_size=dynamic_batch_size,
+                read_ahead=self.batch_read_ahead,
+                key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
+                batch_size=dynamic_batch_size,
                 shuffle=self.shuffle,
                 seed=self.seed,
             )

@@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen):
             s2s_target_max_length = max([len(x[3]) for x in batch])

             mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
-                                     fill_value=self.dictionary.pad())
+                                     fill_value=self.dictionary.pad())
             mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
                                      fill_value=self.dictionary.pad())
             s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
-                                     fill_value=self.dictionary.pad())
+                                     fill_value=self.dictionary.pad())
             s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
                                      fill_value=self.dictionary.pad())
             s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
-                                         fill_value=self.dictionary.pad())
+                                         fill_value=self.dictionary.pad())

             for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
                 mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids

@@ -182,7 +177,7 @@ class MLMLoader(BaseBatchGen):
                 s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
                 s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
                 s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
-
+
             ret_batch = {
                 'net_input': {
                     'src_tokens': mlm_source_ids.astype(np.int64),

@@ -199,16 +194,16 @@ class MLMLoader(BaseBatchGen):
         )

         return padded_batches
-
+
     def _prepare(self, _random, doc):
         nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
         nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
         return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
-
+
     def _mask_lm(self, _random, doc):
         def mask_tokens():
-            return f"<mask>"
+            return "<mask>"

         length = len(doc)
         mask_tokens_num = int(length * self.args.mask_prob)
         mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)

@@ -222,11 +217,11 @@ class MLMLoader(BaseBatchGen):
                 # masked_tokens.append(nonmasked_tokens[position])
                 masked_tokens[position] = nonmasked_tokens[position]
                 nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
-
+
         return nonmasked_tokens, masked_tokens

     def _span_corruption(self, _random, doc):
-
+
         def mask_tokens(i):
             return f"<mask_{i}>"

@@ -244,7 +239,7 @@ class MLMLoader(BaseBatchGen):
         _random.shuffle(possible_split_positions)
         noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
         noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
-
+
         possible_insert_positions = list(range(nonnoise_tokens_num))
         _random.shuffle(possible_insert_positions)
         noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])

@@ -260,14 +255,14 @@ class MLMLoader(BaseBatchGen):
                 noise_spans.append(doc[start_pos:end_pos])
             else:
                 noise_spans.append([mask_id] + doc[start_pos:end_pos])
-
+
             if getattr(self.args, "remove_source_sentinel", False):
                 nonnoise_spans.extend(doc[last_end:start_pos])
             else:
                 nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
-
+
             last_end = end_pos
-
+
         nonnoise_spans.extend(doc[last_end:])
         noise_spans = sum(noise_spans, [])

@@ -276,10 +271,10 @@ class MLMLoader(BaseBatchGen):
     def _read_from_files(self, source_file, source_lang):
         # data = []
         file_path = os.path.join(self.data_dir, source_file)
-
+
         if not os.path.exists(file_path):
             print('| file {} not exists'.format(file_path), flush=True)
-            return iter([]) # skip bad file
+            return iter([])  # skip bad file

         with open(file_path, 'r', encoding='utf8') as f:
             lines = f.read().strip().split('\n')

@@ -292,7 +287,7 @@ class MLMLoader(BaseBatchGen):
                     yield doc
                 doc = [self.dictionary.bos()]
                 continue
-
+
             tokenized_line = self.tokenizer.EncodeAsPieces(line)
             tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]

@@ -308,4 +303,4 @@ class MLMLoader(BaseBatchGen):
             # data.append(doc)
             yield doc

-        # return data
+        # return data

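The `dynamic_batch_size` change above only splits one long expression into two statements; the arithmetic is unchanged: fit as many samples as the token budget allows, then round down to a multiple of the required batch-size multiple, never going below 1. A small self-contained sketch of that rule (the sample lengths and limits are made up for illustration):

# Round the token-budget batch size down to a multiple of
# required_batch_size_multiple, but never below 1.
def dynamic_batch_size(sample, max_tokens=4096, required_batch_size_multiple=8):
    lengths = [len(x) for x in sample]
    batch_size = max_tokens // max(lengths)
    batch_size = batch_size // required_batch_size_multiple * required_batch_size_multiple
    return max(1, batch_size)

print(dynamic_batch_size([[0] * 100, [0] * 250]))  # 4096 // 250 = 16, already a multiple of 8
print(dynamic_batch_size([[0] * 500, [0] * 700]))  # 4096 // 700 = 5 -> rounds to 0 -> clamped to 1
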
@@ -1,14 +1,13 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

 import os
-import gzip
 import numpy as np
 from random import Random
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Optional
-import collections
 from infinibatch import iterators

+
 def apply_to_sample(f, sample):
     if hasattr(sample, "__len__") and len(sample) == 0:
         return {}

@@ -34,6 +33,7 @@ def apply_to_sample(f, sample):

     return _apply(sample)

+
 class NativeCheckpointableIterator(iterators.CheckpointableIterator):
     def __init__(self, iterable: Iterable):
         self._input_iterable = iterable

@@ -44,13 +44,16 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):

     def setstate(self, checkpoint: Optional[Dict]):
         self._iterator = iter(self._input_iterable)
-        self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
+        self._num_items_yielded = iterators._advance_iterator(
+            self._iterator,
+            checkpoint['num_items_yielded']
+        ) if checkpoint is not None else 0

     def __next__(self):
         item = next(self._iterator)
         self._num_items_yielded += 1
         return item

     def close(self):
         pass

@@ -61,17 +64,17 @@ class WeightIterator(object):
         self.seed = seed
         self.control_index = list(range(len(weights)))
         self.setstate(None)
-
+
     def __iter__(self):
         return self
-
+
     def getstate(self):
         return {"random_state": self._random_state}

     def setstate(self, checkpoint):
         self._random_state = checkpoint["random_state"] if checkpoint else None
         self._random = None  # this will trigger the lazy initialization in self.__next__
-
+
     def __next__(self):
         if self._random is None:
             self._random = Random(self.seed)

@@ -80,6 +83,6 @@ class WeightIterator(object):
         idx = self._random.choices(self.control_index, self.weights)[0]
         self._random_state = self._random.getstate()
         return idx
-
+
     def close(self):
-        pass
+        pass

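Both `NativeCheckpointableIterator` and `WeightIterator` follow the infinibatch getstate/setstate convention: a checkpoint is a plain dict that lets a freshly constructed iterator fast-forward to the same position. A toy illustration of that protocol, assuming nothing beyond the standard library (this is not the infinibatch implementation itself):

class CountingIterator:
    """Toy checkpointable iterator: the checkpoint records how many items were yielded."""

    def __init__(self, items):
        self._items = list(items)
        self.setstate(None)

    def getstate(self):
        return {"num_items_yielded": self._pos}

    def setstate(self, checkpoint):
        self._pos = checkpoint["num_items_yielded"] if checkpoint is not None else 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._items):
            raise StopIteration
        item = self._items[self._pos]
        self._pos += 1
        return item


it = CountingIterator("abcde")
next(it), next(it)            # consume 'a' and 'b'
state = it.getstate()         # {'num_items_yielded': 2}
restored = CountingIterator("abcde")
restored.setstate(state)
assert next(restored) == "c"  # resumes where the checkpoint was taken
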
@@ -10,16 +10,13 @@ import logging
 import os
 from argparse import Namespace
 import json
-from omegaconf import MISSING, II, OmegaConf
-from typing import Any
+from omegaconf import MISSING, II

 import numpy as np
-from fairseq import utils
 from fairseq.data import Dictionary
 from fairseq.tasks import FairseqTask, register_task
 from .data.mlm_loader import MLMLoader
 from fairseq.dataclass import FairseqDataclass, ChoiceEnum
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
 import sentencepiece as spm

 logger = logging.getLogger(__name__)

@@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
 SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
 SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])

+
 @dataclass
 class PretrainingConfig(FairseqDataclass):
     data: str = field(

@@ -163,11 +161,11 @@ class PLMTask(FairseqTask):
             'shuffle': True if split == 'train' else False,
         }
         self.datasets[split] = Namespace(**self.datasets[split])
-
+
     def dataset(self, split):
         if split not in self.datasets:
             raise KeyError("Dataset not loaded: " + split)
-
+
         return self.datasets[split]

     def get_batch_iterator(

@@ -207,4 +205,4 @@ class PLMTask(FairseqTask):

     @property
     def target_dictionary(self):
-        return self.dictionary
+        return self.dictionary

@@ -1,11 +1,11 @@
 # Copyright (c) 2022 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

+# flake8: noqa
 import models
 import tasks

 from fairseq_cli.train import cli_main
-

 if __name__ == "__main__":
-    cli_main()
+    cli_main()

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
 import torch.distributed as dist
 import math

+
 @torch.no_grad()
 def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
     def grad_exists(p):

@@ -75,4 +76,4 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
         clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
         for g in grads + expert_grads + sharded_grads + base_expert_grads:
             g.mul_(clip_coef)
-    return total_norm
+    return total_norm

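The `clip_coef` line above is the usual global-norm gradient clipping rule: scale every gradient by max_norm / (total_norm + eps), capped at 1 so gradients are only ever shrunk, never amplified. A standalone sketch of that arithmetic with dummy gradients (this is not the fairseq/MoE-aware version; the helper name and values are illustrative):

import torch

def clip_by_total_norm_(grads, max_norm, eps=1e-6):
    # Compute the global L2 norm over all gradients, then rescale them in place.
    total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    clip_coef = (max_norm / (total_norm + eps)).clamp_(max=1)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm

grads = [torch.full((3,), 4.0), torch.full((4,), 2.0)]
norm = clip_by_total_norm_(grads, max_norm=1.0)
print(norm)                                              # global norm before clipping (8.0)
print(torch.stack([g.norm(2) for g in grads]).norm(2))   # ~1.0 after clipping
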
setup.py
@@ -25,4 +25,4 @@ setup(
     classifiers=[
         'Programming Language :: Python :: 3',
     ],
-)
+)

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -23,6 +23,7 @@ testcases = [
     {"fsdp": True}
 ]

+
 @pytest.mark.parametrize("args", testcases)
 def test_decoder(args):
     config = DecoderConfig(**args)

@@ -23,6 +23,7 @@ testcases = [
     {"fsdp": True}
 ]

+
 @pytest.mark.parametrize("args", testcases)
 def test_encoder(args):
     config = EncoderConfig(**args)

@@ -25,13 +25,14 @@ testcases = [
     {"fsdp": True}
 ]

+
 @pytest.mark.parametrize("args", testcases)
 def test_decoder(args):
     config = EncoderDecoderConfig(**args)
     model = EncoderDecoder(
         config,
-        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
-        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
+        encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
+        decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
         encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
         decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
     )

@@ -41,6 +42,6 @@ def test_decoder(args):

     model(
         src_tokens=src_tokens,
-        prev_output_tokens=prev_output_tokens,
+        prev_output_tokens=prev_output_tokens,
         features_only=True,
     )

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]

@@ -355,7 +355,8 @@ def top2gating(
     if has_tutel:
         locations1_s = torch.sum(locations1 * mask1_, dim=1)
         locations2_s = torch.sum(locations2 * mask2_, dim=1)
-        return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
+        return l_aux, metadata, capacity, num_experts, \
+            [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]

     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)

@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
+# Licensed under The MIT License [see LICENSE for details]