diff --git a/examples/__init__.py b/examples/__init__.py index 6d707f2..3ae31e2 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/examples/fairseq/__init__.py b/examples/fairseq/__init__.py index 6d707f2..3ae31e2 100644 --- a/examples/fairseq/__init__.py +++ b/examples/fairseq/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/examples/fairseq/generate.py b/examples/fairseq/generate.py index 7c46266..e4f0662 100644 --- a/examples/fairseq/generate.py +++ b/examples/fairseq/generate.py @@ -1,10 +1,11 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] +# flake8: noqa import models import tasks from fairseq_cli.generate import cli_main if __name__ == "__main__": - cli_main() \ No newline at end of file + cli_main() diff --git a/examples/fairseq/interactive.py b/examples/fairseq/interactive.py index bcf6b64..30ec139 100644 --- a/examples/fairseq/interactive.py +++ b/examples/fairseq/interactive.py @@ -1,10 +1,11 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] +# flake8: noqa import models import tasks from fairseq_cli.interactive import cli_main if __name__ == "__main__": - cli_main() \ No newline at end of file + cli_main() diff --git a/examples/fairseq/models/__init__.py b/examples/fairseq/models/__init__.py index cbaa46b..7a0f259 100644 --- a/examples/fairseq/models/__init__.py +++ b/examples/fairseq/models/__init__.py @@ -33,4 +33,4 @@ for file in os.listdir(models_dir): ) group_args = parser.add_argument_group("Additional command-line arguments") MODEL_REGISTRY[model_name].add_args(group_args) - globals()[model_name + "_parser"] = parser \ No newline at end of file + globals()[model_name + "_parser"] = parser diff --git a/examples/fairseq/models/bert.py b/examples/fairseq/models/bert.py index d7c3202..f9a1eeb 100644 --- a/examples/fairseq/models/bert.py +++ b/examples/fairseq/models/bert.py @@ -1,24 +1,21 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] -import math import logging -from typing import Any, Dict, List, Optional +from typing import Optional from dataclasses import dataclass, field import torch import torch.nn as nn import torch.nn.functional as F from fairseq import utils -from fairseq.distributed import fsdp_wrap -from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture +from fairseq.models import BaseFairseqModel, register_model, register_model_architecture from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models.transformer import ( DEFAULT_MIN_PARAMS_TO_WRAP, Embedding ) from fairseq.modules import PositionalEmbedding from fairseq.models.squad import SQuADHead -from torch import Tensor from omegaconf import II from .machine_translation import MTEncoder as Encoder from torchscale.architecture.config import EncoderConfig @@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024 logger = logging.getLogger(__name__) + @dataclass class BertConfig(FairseqDataclass): activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( @@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass): 
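Note: the `# flake8: noqa` headers added to generate.py and interactive.py are needed because `import models` and `import tasks` look unused to the linter: they exist purely for the side effect of running fairseq's `@register_model` / `@register_task` decorators. The registry loop whose tail appears in models/__init__.py above follows the usual auto-import pattern; a minimal sketch (the file layout is assumed, not copied from the repo):

import importlib
import os

# import every sibling module so its @register_model decorator runs;
# MODEL_REGISTRY is populated purely as a side effect of these imports
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: -len(".py")]
        importlib.import_module("models." + model_name)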
moe_eval_capacity_token_fraction: Optional[float] = field( default=0.25, metadata={ - "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]." + "help": ( + "Default: 0.25, Fraction of tokens as capacity during validation, " + "if set to negative, use same as training. range: (0.0, 1.0]." + ) } ) moe_normalize_expert_grad: Optional[str] = field( @@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass): default=False, metadata={"help": "records all to all perf stats during distributed training"} ) dummy_a2a: Optional[bool] = field( - default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"} + default=False, metadata={ + "help": "By passes all to all during distributed training by returning the input buffer as output"} ) moe_batch_prioritized_routing: Optional[bool] = field( default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."} @@ -202,7 +204,7 @@ class BertConfig(FairseqDataclass): subln: Optional[bool] = field( default=False, ) - + @register_model("mlm", dataclass=BertConfig) class BertModel(BaseFairseqModel): @@ -245,9 +247,9 @@ class BertModel(BaseFairseqModel): config.override(args) encoder = Encoder( - config, - embed_tokens=embed_tokens, - embed_positions=embed_positions, + config, + embed_tokens=embed_tokens, + embed_positions=embed_positions, output_projection=lm_head, is_encoder_decoder=False, dictionary=task.dictionary, @@ -259,14 +261,14 @@ class BertModel(BaseFairseqModel): def build_embedding(cls, args, dictionary, embed_dim, path=None): embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) return embed_tokens - + @classmethod def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight): return LMHead(embed_dim, output_dim, activation_fn, weight) - + def output_layer(self, features, masked_tokens=None): return self.encoder.output_projection(features, masked_tokens=masked_tokens) - + def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): """Register a classification head.""" if name in self.classification_heads: @@ -286,12 +288,12 @@ class BertModel(BaseFairseqModel): self.args.pooler_activation_fn, self.args.pooler_dropout, ) - + def register_question_answering_head(self, name, num_classes=None): self.classification_heads[name] = SQuADHead( self.args.encoder_embed_dim, ) - + def upgrade_state_dict_named(self, state_dict, name): prefix = name + '.' if name != '' else '' @@ -342,15 +344,16 @@ class BertModel(BaseFairseqModel): if prefix + 'classification_heads.' + k not in state_dict: logger.info('Overwriting ' + prefix + 'classification_heads.' + k) state_dict[prefix + 'classification_heads.' 
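Note: the re-wrapped `moe_eval_capacity_token_fraction` help text above encodes a small rule worth stating directly. A hypothetical helper (names are illustrative, not the fairseq API):

def eval_capacity_tokens(num_tokens: int, eval_fraction: float, train_capacity: int) -> int:
    # fraction of tokens kept as expert capacity at validation time;
    # a negative value means "reuse the training-time capacity"
    if eval_fraction < 0.0:
        return train_capacity
    return max(1, int(eval_fraction * num_tokens))  # fraction lies in (0.0, 1.0]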
+ k] = v - + def forward( - self, - src_tokens=None, + self, + src_tokens=None, features_only=False, return_all_hiddens=False, - classification_head_name=None, + classification_head_name=None, masked_tokens=None, - **kwargs): + **kwargs + ): encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens) x, extra = encoder_out["encoder_out"], encoder_out x = x.transpose(0, 1) @@ -362,7 +365,7 @@ class BertModel(BaseFairseqModel): return x, extra - + class ClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -389,6 +392,7 @@ class ClassificationHead(nn.Module): x = self.out_proj(x) return x + class LMHead(nn.Module): """Head for masked language modeling.""" @@ -459,4 +463,4 @@ def base_unilm_architecture(args): args.checkpoint_activations = getattr(args, "checkpoint_activations", False) args.offload_activations = getattr(args, "offload_activations", False) if args.offload_activations: - args.checkpoint_activations = True \ No newline at end of file + args.checkpoint_activations = True diff --git a/examples/fairseq/models/language_modeling.py b/examples/fairseq/models/language_modeling.py index d754a8f..95fae3c 100644 --- a/examples/fairseq/models/language_modeling.py +++ b/examples/fairseq/models/language_modeling.py @@ -6,12 +6,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import math +import logging from dataclasses import dataclass, field from typing import Optional import torch -from fairseq import options, utils +from fairseq import utils from fairseq import distributed_utils from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.models import ( @@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig from omegaconf import II DEFAULT_MAX_TARGET_POSITIONS = 1024 -import logging logger = logging.getLogger(__name__) + @dataclass class LanguageConfig(FairseqDataclass): activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( @@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass): moe_eval_capacity_token_fraction: Optional[float] = field( default=0.25, metadata={ - "help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]." + "help": ( + "Default: 0.25, Fraction of tokens as capacity during validation, " + "if set to negative, use same as training. range: (0.0, 1.0]." 
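Note: the `LMHead` defined above only needs logits at masked positions, so it gathers those rows before the expensive vocabulary projection. A minimal self-contained sketch of that head, assuming GELU in place of fairseq's configurable `activation_fn` string:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLMHead(nn.Module):
    # minimal masked-LM head: dense -> GELU -> layer norm ->
    # projection onto the (usually tied) embedding matrix
    def __init__(self, embed_dim, vocab_size, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = weight  # typically tied to the input embedding
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, features, masked_tokens=None):
        if masked_tokens is not None:
            # project only the masked positions; the rest is wasted compute
            features = features[masked_tokens, :]
        x = self.layer_norm(F.gelu(self.dense(features)))
        return F.linear(x, self.weight) + self.bias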
+ ) } ) moe_normalize_expert_grad: Optional[str] = field( @@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass): default=False, metadata={"help": "records all to all perf stats during distributed training"} ) dummy_a2a: Optional[bool] = field( - default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"} + default=False, metadata={ + "help": "By passes all to all during distributed training by returning the input buffer as output"} ) moe_batch_prioritized_routing: Optional[bool] = field( default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."} @@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel): output_projection.weight = embed_tokens.weight else: output_projection = torch.nn.Linear( - decoder_embed_dim, len(task.dictionary), bias=False + args.decoder_embed_dim, len(task.dictionary), bias=False ) torch.nn.init.normal_( - output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5 + output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5 ) if ( @@ -252,22 +256,23 @@ class LanguageModel(FairseqLanguageModel): and getattr(args, 'ddp_backend', None) != "fully_sharded" ) ): - assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm" - + assert args.fp16_no_flatten_grads, \ + "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm" + args.ddp_rank = distributed_utils.get_data_parallel_rank() config = DecoderConfig() config.override(args) decoder = LMDecoder( - config, - embed_tokens, + config, + embed_tokens, embed_positions, output_projection, is_encoder_decoder=False, dictionary=task.dictionary, ) - + return cls(args, decoder) @classmethod @@ -283,7 +288,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder): def max_positions(self): return self.embed_positions.max_positions - + def reorder_incremental_state_scripting( self, incremental_state, @@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder): result = incremental_state[module][key].index_select(0, new_order) incremental_state[module][key] = result + @register_model_architecture("lm", "lm_base") def base_lm_architecture(args): # backward compatibility for older model checkpoints @@ -357,4 +363,3 @@ def base_lm_architecture(args): args.offload_activations = getattr(args, "offload_activations", False) if args.offload_activations: args.checkpoint_activations = True - diff --git a/examples/fairseq/models/machine_translation.py b/examples/fairseq/models/machine_translation.py index f2e1b1a..8bdea52 100644 --- a/examples/fairseq/models/machine_translation.py +++ b/examples/fairseq/models/machine_translation.py @@ -6,33 +6,20 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
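Note: the substantive fix in language_modeling.py above replaces the undefined name `decoder_embed_dim` with `args.decoder_embed_dim` in the untied-projection branch. A condensed sketch of both branches (illustrative names, not the fairseq signature):

import torch.nn as nn

def build_output_projection(embed_tokens, vocab_size, embed_dim, share_embed):
    proj = nn.Linear(embed_dim, vocab_size, bias=False)
    if share_embed:
        proj.weight = embed_tokens.weight  # tie the softmax to the input embedding
    else:
        # scaled init, matching the fixed line: std = embed_dim ** -0.5
        nn.init.normal_(proj.weight, mean=0, std=embed_dim ** -0.5)
    return proj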
-import functools -import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch -import torch.nn as nn from fairseq import utils -from fairseq.distributed import utils as dist_utils, fsdp_wrap +from fairseq.distributed import utils as fsdp_wrap from fairseq import distributed_utils -from fairseq import checkpoint_utils from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, - FairseqIncrementalDecoder, register_model, register_model_architecture, ) from fairseq.models.transformer import Embedding -from fairseq.modules import ( - AdaptiveSoftmax, - FairseqDropout, - LayerDropModuleList, - LayerNorm, - PositionalEmbedding, - SinusoidalPositionalEmbedding, -) -from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules import PositionalEmbedding from torchscale.architecture.encoder import Encoder from torchscale.architecture.config import EncoderConfig, DecoderConfig from .language_modeling import LMDecoder as MTDecoder @@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel): help="Use FP32 computations in MoE top2 gating function") parser.add_argument('--moe-second-expert-policy', type=str, default='sampling', help="policy for second expert, options: all/sampling/random") - parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true', - help="whether to normalize gate probs before or after dropping experts for capacity and randomization") + parser.add_argument( + '--moe-normalize-gate-prob-before-dropping', default=False, action='store_true', + help=( + "whether to normalize gate probs before or after dropping experts " + "for capacity and randomization" + ) + ) parser.add_argument('--moe-expert-ffn-dim', type=int, default=0, help="MoE Expert FFN dimension") parser.add_argument('--moe-top1-expert', default=False, action='store_true', help="Use top1 gate instead of top2") - parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25, - help="Fraction of tokens as capacity during validation" + \ - "if set to negative, use same as training. range: (0.0, 1.0].") + parser.add_argument( + '--moe-eval-capacity-token-fraction', type=float, default=0.25, + help=( + "Fraction of tokens as capacity during validation" + "if set to negative, use same as training. range: (0.0, 1.0]." 
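Note: one hazard in re-wrapping argparse help strings, as done for `--moe-eval-capacity-token-fraction` above: adjacent string literals concatenate with no separator, so each wrapped fragment must carry its own trailing space, otherwise "...validation" and "if set..." fuse into "validationif set". A minimal illustration:

msg = (
    "Fraction of tokens as capacity during validation"
    "if set to negative, use same as training."
)
assert "validationif" in msg  # the fragments fused with no separator

msg_ok = (
    "Fraction of tokens as capacity during validation; "  # trailing space kept
    "if set to negative, use same as training."
)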
+ ) + ) parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size', help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'") - parser.add_argument('--use-moe-pad-mask', default=False, action='store_true', help="Don't route padding tokens to any expert") parser.add_argument('--use-xmoe', default=False, action='store_true', @@ -207,7 +202,7 @@ class TranslationModel(FairseqEncoderDecoderModel): args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS if getattr(args, "max_target_positions", None) is None: args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS - + args.ddp_rank = distributed_utils.get_data_parallel_rank() src_dict, tgt_dict = task.source_dictionary, task.target_dictionary @@ -279,18 +274,18 @@ class TranslationModel(FairseqEncoderDecoderModel): encoder = cls.build_encoder( args, - encoder_embed_tokens, + encoder_embed_tokens, encoder_embed_positions, - src_dict, + src_dict, ) decoder = cls.build_decoder( - args, + args, decoder_embed_tokens, decoder_embed_positions, output_projection, tgt_dict, ) - + if not args.share_all_embeddings: min_params_to_wrap = getattr( args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP @@ -317,9 +312,9 @@ class TranslationModel(FairseqEncoderDecoderModel): config.override(args) return MTEncoder( - config, - embed_tokens, - embed_positions, + config, + embed_tokens, + embed_positions, is_encoder_decoder=True, dictionary=dictionary, ) @@ -330,8 +325,8 @@ class TranslationModel(FairseqEncoderDecoderModel): config.override(args) return MTDecoder( - config, - embed_tokens, + config, + embed_tokens, embed_positions, output_projection, is_encoder_decoder=True, @@ -348,7 +343,7 @@ class TranslationModel(FairseqEncoderDecoderModel): **kwargs ): encoder_out = self.encoder( - src_tokens, + src_tokens, return_all_hiddens=return_all_hiddens ) decoder_out = self.decoder( @@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder): def max_positions(self): return self.embed_positions.max_positions + @register_model_architecture("mt", "mt_base") def base_architecture(args): args.encoder_embed_path = getattr(args, "encoder_embed_path", None) diff --git a/examples/fairseq/tasks/__init__.py b/examples/fairseq/tasks/__init__.py index 86d3d37..dfbe5e0 100644 --- a/examples/fairseq/tasks/__init__.py +++ b/examples/fairseq/tasks/__init__.py @@ -32,4 +32,4 @@ for file in os.listdir(tasks_dir): # fmt: on group_args = parser.add_argument_group("Additional command-line arguments") TASK_REGISTRY[task_name].add_args(group_args) - globals()[task_name + "_parser"] = parser \ No newline at end of file + globals()[task_name + "_parser"] = parser diff --git a/examples/fairseq/tasks/data/__init__.py b/examples/fairseq/tasks/data/__init__.py index 6d707f2..3ae31e2 100644 --- a/examples/fairseq/tasks/data/__init__.py +++ b/examples/fairseq/tasks/data/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/examples/fairseq/tasks/data/basic_loader.py b/examples/fairseq/tasks/data/basic_loader.py index ca9a3b5..adce2f8 100644 --- a/examples/fairseq/tasks/data/basic_loader.py +++ b/examples/fairseq/tasks/data/basic_loader.py @@ -1,14 +1,11 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] -import math -import re -import sys -import time import torch from infinibatch.iterators import CheckpointableIterator from . 
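Note: `build_model`'s `--share-all-embeddings` path above hands one embedding matrix to both the encoder and the decoder (and, through `share_decoder_input_output_embed`, to the output projection). A condensed sketch of that branch, assuming fairseq-style dictionaries with a `pad()` method:

from fairseq.models.transformer import Embedding

def build_shared_embeddings(src_dict, tgt_dict, embed_dim):
    # sharing is only valid when source and target use a joined vocabulary
    if src_dict != tgt_dict:
        raise ValueError("--share-all-embeddings requires a joined dictionary")
    shared = Embedding(len(src_dict), embed_dim, src_dict.pad())
    return shared, shared  # encoder and decoder reuse the same module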
import utils + class BaseBatchGen(CheckpointableIterator): """ This is a base class for batch generators that use infinibatch @@ -26,7 +23,7 @@ class BaseBatchGen(CheckpointableIterator): Build infinibatch iterator and assign to self._iter """ raise NotImplementedError() - + def _move_to_tensor(self, batch): def to_tensor(x): @@ -47,16 +44,16 @@ class BaseBatchGen(CheckpointableIterator): def __next__(self): return next(self._iter) - + def setstate(self, value): self._iter.setstate(value) - + def getstate(self): return self._iter.getstate() - + def close(self): self._iter.close() - + def __len__(self) -> int: return 819200000 @@ -78,4 +75,4 @@ class BaseBatchGen(CheckpointableIterator): @property def first_batch(self): - return "DUMMY" \ No newline at end of file + return "DUMMY" diff --git a/examples/fairseq/tasks/data/mlm_loader.py b/examples/fairseq/tasks/data/mlm_loader.py index 4d8c712..6805347 100644 --- a/examples/fairseq/tasks/data/mlm_loader.py +++ b/examples/fairseq/tasks/data/mlm_loader.py @@ -1,13 +1,8 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] -import glob import os -import torch import numpy as np -import time -import json -import random import itertools import copy @@ -55,15 +50,15 @@ class MLMLoader(BaseBatchGen): self.batch_read_ahead = args.batch_read_ahead self._build_iter() - + def _build_iter(self): tokenized_lines = self._multilingual_tokenize() self.padded_batches = self._batchify(tokenized_lines) - + prefetch_batches = iterators.PrefetchIterator( - self.padded_batches, - buffer_size=10000, - buffer_in_main_process=True, + self.padded_batches, + buffer_size=10000, + buffer_in_main_process=True, log_empty_buffer_warning=True and self.shard_id == 0, ) @@ -85,14 +80,14 @@ class MLMLoader(BaseBatchGen): weights.append(float(data['weight'])) else: weights.append(int(data['count'])) - + if len(multilingual_iters) == 1: return multilingual_iters[0] sampling_iterator = WeightIterator(weights) control_iterator = NativeCheckpointableIterator(sampling_iterator) tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters) - + return tokenized_lines def _tokenize(self, data): @@ -109,7 +104,7 @@ class MLMLoader(BaseBatchGen): dataset = list( zip( data['source'], - itertools.repeat(data['source_lang']), + itertools.repeat(data['source_lang']), ) ) @@ -117,27 +112,26 @@ class MLMLoader(BaseBatchGen): chunk_files = \ iterators.InfinitePermutationSourceIterator( dataset, - seed=self.seed, - shuffle=self.shuffle, - num_instances=self.num_shards, + seed=self.seed, + shuffle=self.shuffle, + num_instances=self.num_shards, instance_rank=self.shard_id, ) else: chunk_files = \ iterators.ChunkedSourceIterator( dataset, - num_instances=self.num_shards, + num_instances=self.num_shards, instance_rank=self.shard_id, ) - + tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files)) tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed) - + return tokenized_lines - def _batchify(self, lines): - + if self.max_sentences is not None: if self.batch_read_ahead > 0: lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed) @@ -145,14 +139,15 @@ class MLMLoader(BaseBatchGen): else: def dynamic_batch_size(sample): lengths = [len(x) for x in sample] - batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple + batch_size = self.max_tokens // max(lengths) + 
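Note: `BaseBatchGen` above is a thin shell over infinibatch's `CheckpointableIterator` protocol: `getstate()` snapshots the whole pipeline and `setstate()` rewinds it, which is what lets training resume mid-epoch. A self-contained demonstration on toy data, assuming the standard infinibatch iterators used elsewhere in this diff:

from infinibatch import iterators

source = iterators.ChunkedSourceIterator(list(range(100)), num_instances=1, instance_rank=0)
batches = iterators.FixedBatchIterator(source, batch_size=8)

next(batches)
next(batches)
checkpoint = batches.getstate()  # snapshot after two batches
third = next(batches)

batches.setstate(checkpoint)  # rewind the pipeline
assert next(batches) == third  # iteration resumes exactly where it was saved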
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple return max(1, batch_size) - + batches = iterators.BucketedReadaheadBatchIterator( lines, - read_ahead=self.batch_read_ahead, - key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None, - batch_size=dynamic_batch_size, + read_ahead=self.batch_read_ahead, + key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None, + batch_size=dynamic_batch_size, shuffle=self.shuffle, seed=self.seed, ) @@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen): s2s_target_max_length = max([len(x[3]) for x in batch]) mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32, - fill_value=self.dictionary.pad()) + fill_value=self.dictionary.pad()) mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32, fill_value=self.dictionary.pad()) s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32, - fill_value=self.dictionary.pad()) + fill_value=self.dictionary.pad()) s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32, fill_value=self.dictionary.pad()) s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32, - fill_value=self.dictionary.pad()) + fill_value=self.dictionary.pad()) for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch): mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids @@ -182,7 +177,7 @@ class MLMLoader(BaseBatchGen): s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:] s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1] - + ret_batch = { 'net_input': { 'src_tokens': mlm_source_ids.astype(np.int64), @@ -199,16 +194,16 @@ class MLMLoader(BaseBatchGen): ) return padded_batches - + def _prepare(self, _random, doc): nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc) nonnoise_spans, noise_spans = self._span_corruption(_random, doc) return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans - + def _mask_lm(self, _random, doc): def mask_tokens(): - return f"" - + return "" + length = len(doc) mask_tokens_num = int(length * self.args.mask_prob) mask_tokens_num = min(max(mask_tokens_num, 1), length - 1) @@ -222,11 +217,11 @@ class MLMLoader(BaseBatchGen): # masked_tokens.append(nonmasked_tokens[position]) masked_tokens[position] = nonmasked_tokens[position] nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()] - + return nonmasked_tokens, masked_tokens def _span_corruption(self, _random, doc): - + def mask_tokens(i): return f"" @@ -244,7 +239,7 @@ class MLMLoader(BaseBatchGen): _random.shuffle(possible_split_positions) noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1]) noise_split_positions = [0] + noise_split_positions + [noise_tokens_num] - + possible_insert_positions = list(range(nonnoise_tokens_num)) _random.shuffle(possible_insert_positions) noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num]) @@ -260,14 +255,14 @@ class MLMLoader(BaseBatchGen): noise_spans.append(doc[start_pos:end_pos]) else: noise_spans.append([mask_id] + doc[start_pos:end_pos]) - + if getattr(self.args, "remove_source_sentinel", False): nonnoise_spans.extend(doc[last_end:start_pos]) else: nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id]) - + last_end = end_pos - + nonnoise_spans.extend(doc[last_end:]) noise_spans = sum(noise_spans, []) @@ -276,10 +271,10 @@ class 
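Note: the `dynamic_batch_size` computation split across the two added lines above fits as many of the longest sequences as the token budget allows, then rounds down to `required_batch_size_multiple`. A standalone restatement with illustrative defaults:

def dynamic_batch_size(sample, max_tokens=4096, multiple=8):
    lengths = [len(x) for x in sample]
    batch_size = max_tokens // max(lengths)  # sequences that fit the token budget
    batch_size = batch_size // multiple * multiple  # round down to the multiple
    return max(1, batch_size)  # never return an empty batch

# e.g. longest sequence 300 tokens: 4096 // 300 = 13, rounded down to 8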
MLMLoader(BaseBatchGen): def _read_from_files(self, source_file, source_lang): # data = [] file_path = os.path.join(self.data_dir, source_file) - + if not os.path.exists(file_path): print('| file {} not exists'.format(file_path), flush=True) - return iter([]) # skip bad file + return iter([]) # skip bad file with open(file_path, 'r', encoding='utf8') as f: lines = f.read().strip().split('\n') @@ -292,7 +287,7 @@ class MLMLoader(BaseBatchGen): yield doc doc = [self.dictionary.bos()] continue - + tokenized_line = self.tokenizer.EncodeAsPieces(line) tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index] @@ -308,4 +303,4 @@ class MLMLoader(BaseBatchGen): # data.append(doc) yield doc - # return data \ No newline at end of file + # return data diff --git a/examples/fairseq/tasks/data/utils.py b/examples/fairseq/tasks/data/utils.py index eb2310e..afbcf75 100644 --- a/examples/fairseq/tasks/data/utils.py +++ b/examples/fairseq/tasks/data/utils.py @@ -1,14 +1,13 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] -import os -import gzip import numpy as np from random import Random -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterable, Optional import collections from infinibatch import iterators + def apply_to_sample(f, sample): if hasattr(sample, "__len__") and len(sample) == 0: return {} @@ -34,6 +33,7 @@ def apply_to_sample(f, sample): return _apply(sample) + class NativeCheckpointableIterator(iterators.CheckpointableIterator): def __init__(self, iterable: Iterable): self._input_iterable = iterable @@ -44,13 +44,16 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator): def setstate(self, checkpoint: Optional[Dict]): self._iterator = iter(self._input_iterable) - self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0 + self._num_items_yielded = iterators._advance_iterator( + self._iterator, + checkpoint['num_items_yielded'] + ) if checkpoint is not None else 0 def __next__(self): item = next(self._iterator) self._num_items_yielded += 1 return item - + def close(self): pass @@ -61,17 +64,17 @@ class WeightIterator(object): self.seed = seed self.control_index = list(range(len(weights))) self.setstate(None) - + def __iter__(self): return self - + def getstate(self): return {"random_state": self._random_state} def setstate(self, checkpoint): self._random_state = checkpoint["random_state"] if checkpoint else None self._random = None # this will trigger the lazy initialization in self.__next__ - + def __next__(self): if self._random is None: self._random = Random(self.seed) @@ -80,6 +83,6 @@ class WeightIterator(object): idx = self._random.choices(self.control_index, self.weights)[0] self._random_state = self._random.getstate() return idx - + def close(self): - pass \ No newline at end of file + pass diff --git a/examples/fairseq/tasks/pretraining.py b/examples/fairseq/tasks/pretraining.py index d935b91..bcb6fd1 100644 --- a/examples/fairseq/tasks/pretraining.py +++ b/examples/fairseq/tasks/pretraining.py @@ -10,16 +10,13 @@ import logging import os from argparse import Namespace import json -from omegaconf import MISSING, II, OmegaConf -from typing import Any +from omegaconf import MISSING, II -import numpy as np from fairseq import utils from fairseq.data import Dictionary from fairseq.tasks import FairseqTask, register_task from 
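Note: `WeightIterator` above is what `MLMLoader`'s multiplexing rides on: an endless, checkpointable stream of corpus indices drawn in proportion to the given weights. A usage sketch, assuming its constructor takes `(weights, seed)`:

weights = [0.7, 0.2, 0.1]  # e.g. per-language sampling weights
it = WeightIterator(weights, seed=1)

picks = [next(it) for _ in range(4)]  # indices into the list of corpora
state = it.getstate()
later = [next(it) for _ in range(3)]

it.setstate(state)  # restore the RNG mid-stream
assert [next(it) for _ in range(3)] == later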
.data.mlm_loader import MLMLoader from fairseq.dataclass import FairseqDataclass, ChoiceEnum -from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE import sentencepiece as spm logger = logging.getLogger(__name__) @@ -27,6 +24,7 @@ logger = logging.getLogger(__name__) SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + @dataclass class PretrainingConfig(FairseqDataclass): data: str = field( @@ -163,11 +161,11 @@ class PLMTask(FairseqTask): 'shuffle': True if split == 'train' else False, } self.datasets[split] = Namespace(**self.datasets[split]) - + def dataset(self, split): if split not in self.datasets: raise KeyError("Dataset not loaded: " + split) - + return self.datasets[split] def get_batch_iterator( @@ -207,4 +205,4 @@ class PLMTask(FairseqTask): @property def target_dictionary(self): - return self.dictionary \ No newline at end of file + return self.dictionary diff --git a/examples/fairseq/train.py b/examples/fairseq/train.py index 2c2b120..4fa210d 100644 --- a/examples/fairseq/train.py +++ b/examples/fairseq/train.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] +# flake8: noqa import models import tasks from fairseq_cli.train import cli_main - if __name__ == "__main__": - cli_main() \ No newline at end of file + cli_main() diff --git a/examples/fairseq/utils/__init__.py b/examples/fairseq/utils/__init__.py index 6d707f2..3ae31e2 100644 --- a/examples/fairseq/utils/__init__.py +++ b/examples/fairseq/utils/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/examples/fairseq/utils/sparse_clip.py b/examples/fairseq/utils/sparse_clip.py index 6f244dc..acaa4d2 100644 --- a/examples/fairseq/utils/sparse_clip.py +++ b/examples/fairseq/utils/sparse_clip.py @@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm import torch.distributed as dist import math + @torch.no_grad() def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor: def grad_exists(p): @@ -75,4 +76,4 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) for g in grads + expert_grads + sharded_grads + base_expert_grads: g.mul_(clip_coef) - return total_norm \ No newline at end of file + return total_norm diff --git a/setup.py b/setup.py index 5cf8853..58ecd9c 100644 --- a/setup.py +++ b/setup.py @@ -25,4 +25,4 @@ setup( classifiers=[ 'Programming Language :: Python :: 3', ], -) \ No newline at end of file +) diff --git a/tests/__init__.py b/tests/__init__.py index 6d707f2..3ae31e2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/tests/test_decoder.py b/tests/test_decoder.py index d95080c..58d6987 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -23,6 +23,7 @@ testcases = [ {"fsdp": True} ] + @pytest.mark.parametrize("args", testcases) def test_decoder(args): config = DecoderConfig(**args) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index d179956..838cc90 100644 --- 
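Note: the tail of `clip_grad_norm_` above applies one clip coefficient uniformly to the dense, expert, sharded, and base-expert gradients. The rule in isolation:

import torch

def clip_coef(total_norm: torch.Tensor, max_norm: float, eps: float = 1e-6) -> torch.Tensor:
    # shrink gradients by max_norm / total_norm, but never scale them up
    return (max_norm / (total_norm + eps)).clamp_(max=1)

grad = torch.full((3,), 4.0)
grad.mul_(clip_coef(grad.norm(), max_norm=1.0))  # grad's norm is now ~1.0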
a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -23,6 +23,7 @@ testcases = [ {"fsdp": True} ] + @pytest.mark.parametrize("args", testcases) def test_encoder(args): config = EncoderConfig(**args) diff --git a/tests/test_encoder_decoder.py b/tests/test_encoder_decoder.py index 19672c0..4158309 100644 --- a/tests/test_encoder_decoder.py +++ b/tests/test_encoder_decoder.py @@ -25,13 +25,14 @@ testcases = [ {"fsdp": True} ] + @pytest.mark.parametrize("args", testcases) def test_decoder(args): config = EncoderDecoderConfig(**args) model = EncoderDecoder( config, - encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim), - decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim), + encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim), + decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim), encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim), decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim), ) @@ -41,6 +42,6 @@ def test_decoder(args): model( src_tokens=src_tokens, - prev_output_tokens=prev_output_tokens, + prev_output_tokens=prev_output_tokens, features_only=True, ) diff --git a/torchscale/__init__.py b/torchscale/__init__.py index 6d707f2..3ae31e2 100644 --- a/torchscale/__init__.py +++ b/torchscale/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/torchscale/architecture/__init__.py b/torchscale/architecture/__init__.py index 6d707f2..3ae31e2 100644 --- a/torchscale/architecture/__init__.py +++ b/torchscale/architecture/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/torchscale/component/__init__.py b/torchscale/component/__init__.py index 6d707f2..3ae31e2 100644 --- a/torchscale/component/__init__.py +++ b/torchscale/component/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/torchscale/component/xmoe/__init__.py b/torchscale/component/xmoe/__init__.py index 6d707f2..3ae31e2 100644 --- a/torchscale/component/xmoe/__init__.py +++ b/torchscale/component/xmoe/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) 2022 Microsoft -# Licensed under The MIT License [see LICENSE for details] \ No newline at end of file +# Licensed under The MIT License [see LICENSE for details] diff --git a/torchscale/component/xmoe/routing.py b/torchscale/component/xmoe/routing.py index c882e83..7abce6e 100644 --- a/torchscale/component/xmoe/routing.py +++ b/torchscale/component/xmoe/routing.py @@ -355,7 +355,8 @@ def top2gating( if has_tutel: locations1_s = torch.sum(locations1 * mask1_, dim=1) locations2_s = torch.sum(locations2 * mask2_, dim=1) - return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s] + return l_aux, metadata, capacity, num_experts, \ + [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s] # Store the capacity location for each token locations1_s = torch.sum(locations1 * mask1, dim=1) diff --git a/torchscale/model/__init__.py b/torchscale/model/__init__.py index 
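Note: the tutel branch of `top2gating` above returns, per token, the chosen expert indices, the slot offsets inside each expert, and the gate weights; the `capacity` it also returns bounds how many tokens one expert may accept. A hypothetical helper stating the usual top-2 sizing rule (not the torchscale signature):

import math

def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float = 2.0) -> int:
    # top-2 routing lets each token claim two slots, so capacity is
    # roughly 2 * tokens / experts, rounded up
    return math.ceil(capacity_factor * num_tokens / num_experts)

# 4096 tokens over 32 experts -> 256 slots per expert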
6d707f2..3ae31e2 100644
--- a/torchscale/model/__init__.py
+++ b/torchscale/model/__init__.py
@@ -1,2 +1,2 @@
 # Copyright (c) 2022 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
\ No newline at end of file
+# Licensed under The MIT License [see LICENSE for details]