flake8 lint checks

shumingma 2022-11-26 08:10:15 -08:00
parent 4714557e89
commit 994e4665a2
28 changed files with 168 additions and 163 deletions
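The hunks below are mechanical flake8 cleanups rather than behavioural changes: unused imports are dropped (F401), over-long lines and help strings are wrapped (E501), an f-string with no placeholders becomes a plain string (F541), what appear to be trailing-whitespace and end-of-file-newline fixes touch many single lines (W291/W292), and top-level definitions gain the two blank lines flake8 expects (E302/E305). As a rough illustration only (the snippet below is hypothetical and not taken from this repository), this is the style the linter pushes toward:

# Hypothetical before/after sketch, not code from this commit.
#
# Before (flagged by flake8):
#     import os, math                     # F401: neither os nor math is used
#     def report(status): print(f"done")  # E704 / F541: statement on the def line, f-string without placeholders
#
# After (lint-clean equivalent):

import logging

logger = logging.getLogger(__name__)


def report(status: str) -> None:
    # Long messages are wrapped instead of exceeding the line-length limit (E501).
    logger.info(
        "finished with status %s",
        status,
    )


if __name__ == "__main__":
    report("done")

Running flake8 from the repository root reports these codes per file; the "# flake8: noqa" markers added to the entry scripts below suppress the checks there, presumably because the models and tasks imports exist only for their registration side effects.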


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models
import tasks
from fairseq_cli.generate import cli_main
if __name__ == "__main__":
cli_main()


@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models
import tasks
from fairseq_cli.interactive import cli_main
if __name__ == "__main__":
cli_main()


@@ -33,4 +33,4 @@ for file in os.listdir(models_dir):
)
group_args = parser.add_argument_group("Additional command-line arguments")
MODEL_REGISTRY[model_name].add_args(group_args)
globals()[model_name + "_parser"] = parser


@@ -1,24 +1,21 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import logging
from typing import Any, Dict, List, Optional
from typing import Optional
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
)
from fairseq.modules import PositionalEmbedding
from fairseq.models.squad import SQuADHead
from torch import Tensor
from omegaconf import II
from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig
@@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class BertConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
)
moe_normalize_expert_grad: Optional[str] = field(
@@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -202,7 +204,7 @@ class BertConfig(FairseqDataclass):
subln: Optional[bool] = field(
default=False,
)
@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
@@ -245,9 +247,9 @@ class BertModel(BaseFairseqModel):
config.override(args)
encoder = Encoder(
config,
embed_tokens=embed_tokens,
embed_positions=embed_positions,
output_projection=lm_head,
is_encoder_decoder=False,
dictionary=task.dictionary,
@@ -259,14 +261,14 @@ class BertModel(BaseFairseqModel):
def build_embedding(cls, args, dictionary, embed_dim, path=None):
embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
return embed_tokens
@classmethod
def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
return LMHead(embed_dim, output_dim, activation_fn, weight)
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
"""Register a classification head."""
if name in self.classification_heads:
@@ -286,12 +288,12 @@ class BertModel(BaseFairseqModel):
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
def register_question_answering_head(self, name, num_classes=None):
self.classification_heads[name] = SQuADHead(
self.args.encoder_embed_dim,
)
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
@@ -342,15 +344,16 @@ class BertModel(BaseFairseqModel):
if prefix + 'classification_heads.' + k not in state_dict:
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
def forward(
self,
src_tokens=None,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
masked_tokens=None,
**kwargs):
**kwargs
):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)
@@ -362,7 +365,7 @@ class BertModel(BaseFairseqModel):
return x, extra
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
@@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
x = self.out_proj(x)
return x
class LMHead(nn.Module):
"""Head for masked language modeling."""
@@ -459,4 +463,4 @@ def base_unilm_architecture(args):
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True


@@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import options, utils
from fairseq import utils
from fairseq import distributed_utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
@@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024
import logging
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
)
moe_normalize_expert_grad: Optional[str] = field(
@@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
decoder_embed_dim, len(task.dictionary), bias=False
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
)
if (
@@ -252,22 +256,23 @@ class LanguageModel(FairseqLanguageModel):
and getattr(args, 'ddp_backend', None) != "fully_sharded"
)
):
assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
assert args.fp16_no_flatten_grads, \
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = DecoderConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=False,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
@@ -283,7 +288,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
def max_positions(self):
return self.embed_positions.max_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
@@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
# backward compatibility for older model checkpoints
@@ -357,4 +363,3 @@ def base_lm_architecture(args):
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True


@@ -6,33 +6,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import math
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import utils as dist_utils, fsdp_wrap
from fairseq.distributed import utils as fsdp_wrap
from fairseq import distributed_utils
from fairseq import checkpoint_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules import PositionalEmbedding
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
@@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
help="Use FP32 computations in MoE top2 gating function")
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
help="policy for second expert, options: all/sampling/random")
parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
parser.add_argument(
'--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
help=(
"whether to normalize gate probs before or after dropping experts "
"for capacity and randomization"
)
)
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
help="MoE Expert FFN dimension")
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
help="Use top1 gate instead of top2")
parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
help="Fraction of tokens as capacity during validation" + \
"if set to negative, use same as training. range: (0.0, 1.0].")
parser.add_argument(
'--moe-eval-capacity-token-fraction', type=float, default=0.25,
help=(
"Fraction of tokens as capacity during validation"
"if set to negative, use same as training. range: (0.0, 1.0]."
)
)
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
help="Don't route padding tokens to any expert")
parser.add_argument('--use-xmoe', default=False, action='store_true',
@@ -207,7 +202,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
args.ddp_rank = distributed_utils.get_data_parallel_rank()
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
@@ -279,18 +274,18 @@ class TranslationModel(FairseqEncoderDecoderModel):
encoder = cls.build_encoder(
args,
encoder_embed_tokens,
encoder_embed_positions,
src_dict,
)
decoder = cls.build_decoder(
args,
decoder_embed_tokens,
decoder_embed_positions,
output_projection,
tgt_dict,
)
if not args.share_all_embeddings:
min_params_to_wrap = getattr(
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
@@ -317,9 +312,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
config.override(args)
return MTEncoder(
config,
embed_tokens,
embed_positions,
is_encoder_decoder=True,
dictionary=dictionary,
)
@@ -330,8 +325,8 @@ class TranslationModel(FairseqEncoderDecoderModel):
config.override(args)
return MTDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=True,
@@ -348,7 +343,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
@@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
def max_positions(self):
return self.embed_positions.max_positions
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)


@@ -32,4 +32,4 @@ for file in os.listdir(tasks_dir):
# fmt: on
group_args = parser.add_argument_group("Additional command-line arguments")
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,14 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import re
import sys
import time
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
class BaseBatchGen(CheckpointableIterator):
"""
This is a base class for batch generators that use infinibatch
@@ -26,7 +23,7 @@ class BaseBatchGen(CheckpointableIterator):
Build infinibatch iterator and assign to self._iter
"""
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
@@ -47,16 +44,16 @@ class BaseBatchGen(CheckpointableIterator):
def __next__(self):
return next(self._iter)
def setstate(self, value):
self._iter.setstate(value)
def getstate(self):
return self._iter.getstate()
def close(self):
self._iter.close()
def __len__(self) -> int:
return 819200000
@@ -78,4 +75,4 @@ class BaseBatchGen(CheckpointableIterator):
@property
def first_batch(self):
return "DUMMY"
return "DUMMY"


@@ -1,13 +1,8 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import glob
import os
import torch
import numpy as np
import time
import json
import random
import itertools
import copy
@@ -55,15 +50,15 @@ class MLMLoader(BaseBatchGen):
self.batch_read_ahead = args.batch_read_ahead
self._build_iter()
def _build_iter(self):
tokenized_lines = self._multilingual_tokenize()
self.padded_batches = self._batchify(tokenized_lines)
prefetch_batches = iterators.PrefetchIterator(
self.padded_batches,
buffer_size=10000,
buffer_in_main_process=True,
log_empty_buffer_warning=True and self.shard_id == 0,
)
@@ -85,14 +80,14 @@ class MLMLoader(BaseBatchGen):
weights.append(float(data['weight']))
else:
weights.append(int(data['count']))
if len(multilingual_iters) == 1:
return multilingual_iters[0]
sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
return tokenized_lines
def _tokenize(self, data):
@@ -109,7 +104,7 @@ class MLMLoader(BaseBatchGen):
dataset = list(
zip(
data['source'],
itertools.repeat(data['source_lang']),
)
)
@@ -117,27 +112,26 @@ class MLMLoader(BaseBatchGen):
chunk_files = \
iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
else:
chunk_files = \
iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
return tokenized_lines
def _batchify(self, lines):
if self.max_sentences is not None:
if self.batch_read_ahead > 0:
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
@@ -145,14 +139,15 @@ class MLMLoader(BaseBatchGen):
else:
def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
batch_size = self.max_tokens // max(lengths)
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator(
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
)
@@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen):
s2s_target_max_length = max([len(x[3]) for x in batch])
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
@@ -182,7 +177,7 @@ class MLMLoader(BaseBatchGen):
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
ret_batch = {
'net_input': {
'src_tokens': mlm_source_ids.astype(np.int64),
@@ -199,16 +194,16 @@ class MLMLoader(BaseBatchGen):
)
return padded_batches
def _prepare(self, _random, doc):
nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
def _mask_lm(self, _random, doc):
def mask_tokens():
return f"<mask>"
return "<mask>"
length = len(doc)
mask_tokens_num = int(length * self.args.mask_prob)
mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
@@ -222,11 +217,11 @@ class MLMLoader(BaseBatchGen):
# masked_tokens.append(nonmasked_tokens[position])
masked_tokens[position] = nonmasked_tokens[position]
nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
return nonmasked_tokens, masked_tokens
def _span_corruption(self, _random, doc):
def mask_tokens(i):
return f"<mask_{i}>"
@@ -244,7 +239,7 @@ class MLMLoader(BaseBatchGen):
_random.shuffle(possible_split_positions)
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
possible_insert_positions = list(range(nonnoise_tokens_num))
_random.shuffle(possible_insert_positions)
noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])
@@ -260,14 +255,14 @@ class MLMLoader(BaseBatchGen):
noise_spans.append(doc[start_pos:end_pos])
else:
noise_spans.append([mask_id] + doc[start_pos:end_pos])
if getattr(self.args, "remove_source_sentinel", False):
nonnoise_spans.extend(doc[last_end:start_pos])
else:
nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
last_end = end_pos
nonnoise_spans.extend(doc[last_end:])
noise_spans = sum(noise_spans, [])
@@ -276,10 +271,10 @@ class MLMLoader(BaseBatchGen):
def _read_from_files(self, source_file, source_lang):
# data = []
file_path = os.path.join(self.data_dir, source_file)
if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True)
return iter([]) # skip bad file
with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n')
@@ -292,7 +287,7 @@ class MLMLoader(BaseBatchGen):
yield doc
doc = [self.dictionary.bos()]
continue
tokenized_line = self.tokenizer.EncodeAsPieces(line)
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
@@ -308,4 +303,4 @@ class MLMLoader(BaseBatchGen):
# data.append(doc)
yield doc
# return data


@@ -1,14 +1,13 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import os
import gzip
import numpy as np
from random import Random
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Dict, Iterable, Optional
import collections
from infinibatch import iterators
def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
@@ -34,6 +33,7 @@ def apply_to_sample(f, sample):
return _apply(sample)
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def __init__(self, iterable: Iterable):
self._input_iterable = iterable
@@ -44,13 +44,16 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
self._num_items_yielded = iterators._advance_iterator(
self._iterator,
checkpoint['num_items_yielded']
) if checkpoint is not None else 0
def __next__(self):
item = next(self._iterator)
self._num_items_yielded += 1
return item
def close(self):
pass
@@ -61,17 +64,17 @@ class WeightIterator(object):
self.seed = seed
self.control_index = list(range(len(weights)))
self.setstate(None)
def __iter__(self):
return self
def getstate(self):
return {"random_state": self._random_state}
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None # this will trigger the lazy initialization in self.__next__
def __next__(self):
if self._random is None:
self._random = Random(self.seed)
@@ -80,6 +83,6 @@ class WeightIterator(object):
idx = self._random.choices(self.control_index, self.weights)[0]
self._random_state = self._random.getstate()
return idx
def close(self):
pass


@@ -10,16 +10,13 @@ import logging
import os
from argparse import Namespace
import json
from omegaconf import MISSING, II, OmegaConf
from typing import Any
from omegaconf import MISSING, II
import numpy as np
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
import sentencepiece as spm
logger = logging.getLogger(__name__)
@@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class PretrainingConfig(FairseqDataclass):
data: str = field(
@@ -163,11 +161,11 @@ class PLMTask(FairseqTask):
'shuffle': True if split == 'train' else False,
}
self.datasets[split] = Namespace(**self.datasets[split])
def dataset(self, split):
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
return self.datasets[split]
def get_batch_iterator(
@@ -207,4 +205,4 @@ class PLMTask(FairseqTask):
@property
def target_dictionary(self):
return self.dictionary


@@ -1,11 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# flake8: noqa
import models
import tasks
from fairseq_cli.train import cli_main
if __name__ == "__main__":
cli_main()


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist
import math
@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def grad_exists(p):
@@ -75,4 +76,4 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
for g in grads + expert_grads + sharded_grads + base_expert_grads:
g.mul_(clip_coef)
return total_norm


@@ -25,4 +25,4 @@ setup(
classifiers=[
'Programming Language :: Python :: 3',
],
)


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = DecoderConfig(**args)


@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
config = EncoderConfig(**args)


@@ -25,13 +25,14 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = EncoderDecoderConfig(**args)
model = EncoderDecoder(
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
)
@@ -41,6 +42,6 @@ def test_decoder(args):
model(
src_tokens=src_tokens,
prev_output_tokens=prev_output_tokens,
features_only=True,
)


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


@@ -355,7 +355,8 @@ def top2gating(
if has_tutel:
locations1_s = torch.sum(locations1 * mask1_, dim=1)
locations2_s = torch.sum(locations2 * mask2_, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
return l_aux, metadata, capacity, num_experts, \
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)


@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]