flake8 lint checks

This commit is contained in:
shumingma 2022-11-26 08:10:15 -08:00
parent 4714557e89
commit 994e4665a2
28 changed files with 168 additions and 163 deletions
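
The files below are touched only to satisfy a flake8 run: unused imports are dropped, over-long help strings and call signatures are wrapped, and the three CLI entry scripts that import models and tasks before pulling in fairseq_cli gain a file-level "# flake8: noqa" marker, presumably to silence import-order warnings there. The flake8 configuration itself is not part of this excerpt; the invocation and settings below are an assumed, typical setup, not taken from the repository:

    # hypothetical invocation; the repository's real settings are not shown in this diff
    pip install flake8
    flake8 . --max-line-length 120 --extend-ignore E203,W503

    # equivalent persistent configuration (assumed values) in setup.cfg or .flake8
    [flake8]
    max-line-length = 120
    extend-ignore = E203, W503
    exclude = .git,__pycache__,build,dist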

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
+# flake8: noqa
import models
import tasks
from fairseq_cli.generate import cli_main
if __name__ == "__main__":
cli_main()

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
+# flake8: noqa
import models
import tasks
from fairseq_cli.interactive import cli_main
if __name__ == "__main__":
cli_main()

View File

@@ -33,4 +33,4 @@ for file in os.listdir(models_dir):
)
group_args = parser.add_argument_group("Additional command-line arguments")
MODEL_REGISTRY[model_name].add_args(group_args)
globals()[model_name + "_parser"] = parser

View File

@@ -1,24 +1,21 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import math
import logging
-from typing import Any, Dict, List, Optional
+from typing import Optional
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
-from fairseq.distributed import fsdp_wrap
-from fairseq.models import BaseFairseqModel, FairseqIncrementalDecoder, register_model, register_model_architecture
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
)
from fairseq.modules import PositionalEmbedding
from fairseq.models.squad import SQuADHead
-from torch import Tensor
from omegaconf import II
from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig
@@ -28,6 +25,7 @@ DEFAULT_MAX_SOURCE_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class BertConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -177,7 +175,10 @@ class BertConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
-"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+"help": (
+"Default: 0.25, Fraction of tokens as capacity during validation, "
+"if set to negative, use same as training. range: (0.0, 1.0]."
+)
}
)
moe_normalize_expert_grad: Optional[str] = field(
@@ -190,7 +191,8 @@ class BertConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
-default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+default=False, metadata={
+"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -202,7 +204,7 @@ class BertConfig(FairseqDataclass):
subln: Optional[bool] = field(
default=False,
)
@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
@@ -245,9 +247,9 @@ class BertModel(BaseFairseqModel):
config.override(args)
encoder = Encoder(
config,
embed_tokens=embed_tokens,
embed_positions=embed_positions,
output_projection=lm_head,
is_encoder_decoder=False,
dictionary=task.dictionary,
@@ -259,14 +261,14 @@ class BertModel(BaseFairseqModel):
def build_embedding(cls, args, dictionary, embed_dim, path=None):
embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
return embed_tokens
@classmethod
def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
return LMHead(embed_dim, output_dim, activation_fn, weight)
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
"""Register a classification head."""
if name in self.classification_heads:
@@ -286,12 +288,12 @@ class BertModel(BaseFairseqModel):
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
def register_question_answering_head(self, name, num_classes=None):
self.classification_heads[name] = SQuADHead(
self.args.encoder_embed_dim,
)
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
@@ -342,15 +344,16 @@ class BertModel(BaseFairseqModel):
if prefix + 'classification_heads.' + k not in state_dict:
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
def forward(
self,
src_tokens=None,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
masked_tokens=None,
-**kwargs):
+**kwargs
+):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)
@@ -362,7 +365,7 @@ class BertModel(BaseFairseqModel):
return x, extra
class ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
@@ -389,6 +392,7 @@ class ClassificationHead(nn.Module):
x = self.out_proj(x)
return x
class LMHead(nn.Module):
"""Head for masked language modeling."""
@@ -459,4 +463,4 @@ def base_unilm_architecture(args):
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
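
A recurring pattern in the hunks above is wrapping an over-long help string across adjacent string literals inside parentheses; Python joins adjacent literals at compile time, so the metadata value is unchanged. A minimal sketch of that mechanism (the variable name is illustrative, not from the repository):

    # Adjacent string literals inside parentheses are concatenated into one string,
    # so wrapping a flake8-flagged long line this way does not change its value.
    help_text = (
        "Default: 0.25, Fraction of tokens as capacity during validation, "
        "if set to negative, use same as training. range: (0.0, 1.0]."
    )
    assert "validation, if set to negative" in help_text
    print(len(help_text))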

View File

@@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import math
+import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
-from fairseq import options, utils
+from fairseq import utils
from fairseq import distributed_utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
@@ -29,9 +29,9 @@ from torchscale.architecture.config import DecoderConfig
from omegaconf import II
DEFAULT_MAX_TARGET_POSITIONS = 1024
-import logging
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
@@ -151,7 +151,10 @@ class LanguageConfig(FairseqDataclass):
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
-"help": "Default: 0.25, Fraction of tokens as capacity during validation, if set to negative, use same as training. range: (0.0, 1.0]."
+"help": (
+"Default: 0.25, Fraction of tokens as capacity during validation, "
+"if set to negative, use same as training. range: (0.0, 1.0]."
+)
}
)
moe_normalize_expert_grad: Optional[str] = field(
@@ -164,7 +167,8 @@ class LanguageConfig(FairseqDataclass):
default=False, metadata={"help": "records all to all perf stats during distributed training"}
)
dummy_a2a: Optional[bool] = field(
-default=False, metadata={"help": "By passes all to all during distributed training by returning the input buffer as output"}
+default=False, metadata={
+"help": "By passes all to all during distributed training by returning the input buffer as output"}
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
@@ -238,10 +242,10 @@ class LanguageModel(FairseqLanguageModel):
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
-decoder_embed_dim, len(task.dictionary), bias=False
+args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
-output_projection.weight, mean=0, std=decoder_embed_dim ** -0.5
+output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
)
if (
@@ -252,22 +256,23 @@ class LanguageModel(FairseqLanguageModel):
and getattr(args, 'ddp_backend', None) != "fully_sharded"
)
):
-assert args.fp16_no_flatten_grads, "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
+assert args.fp16_no_flatten_grads, \
+"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = DecoderConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=False,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
@@ -283,7 +288,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
def max_positions(self):
return self.embed_positions.max_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
@@ -294,6 +299,7 @@ class LMDecoder(Decoder, FairseqIncrementalDecoder):
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
# backward compatibility for older model checkpoints
@@ -357,4 +363,3 @@ def base_lm_architecture(args):
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
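
The decoder_embed_dim to args.decoder_embed_dim change in the hunk above looks like more than formatting: the bare name appears to be undefined in that scope, which flake8 would report as F821 and which would raise NameError when an untied output projection is built. A minimal sketch of that failure mode, with illustrative names not taken from the repository:

    # Hypothetical reduction of the fix above: the embed dim lives on the args/config
    # object, so referring to it without the "args." prefix would be an undefined
    # name (flake8 F821) and a NameError at runtime.
    import torch

    class Args:
        decoder_embed_dim = 512

    def build_output_projection(args, vocab_size):
        proj = torch.nn.Linear(args.decoder_embed_dim, vocab_size, bias=False)
        torch.nn.init.normal_(proj.weight, mean=0, std=args.decoder_embed_dim ** -0.5)
        return proj

    print(build_output_projection(Args(), 1000).weight.shape)  # torch.Size([1000, 512])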

View File

@@ -6,33 +6,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import functools
-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
-import torch.nn as nn
from fairseq import utils
-from fairseq.distributed import utils as dist_utils, fsdp_wrap
+from fairseq.distributed import utils as fsdp_wrap
from fairseq import distributed_utils
-from fairseq import checkpoint_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
-FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding
-from fairseq.modules import (
-AdaptiveSoftmax,
-FairseqDropout,
-LayerDropModuleList,
-LayerNorm,
-PositionalEmbedding,
-SinusoidalPositionalEmbedding,
-)
+from fairseq.modules import PositionalEmbedding
-from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
@@ -164,18 +151,26 @@ class TranslationModel(FairseqEncoderDecoderModel):
help="Use FP32 computations in MoE top2 gating function")
parser.add_argument('--moe-second-expert-policy', type=str, default='sampling',
help="policy for second expert, options: all/sampling/random")
-parser.add_argument('--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
-help="whether to normalize gate probs before or after dropping experts for capacity and randomization")
+parser.add_argument(
+'--moe-normalize-gate-prob-before-dropping', default=False, action='store_true',
+help=(
+"whether to normalize gate probs before or after dropping experts "
+"for capacity and randomization"
+)
+)
parser.add_argument('--moe-expert-ffn-dim', type=int, default=0,
help="MoE Expert FFN dimension")
parser.add_argument('--moe-top1-expert', default=False, action='store_true',
help="Use top1 gate instead of top2")
-parser.add_argument('--moe-eval-capacity-token-fraction', type=float, default=0.25,
-help="Fraction of tokens as capacity during validation" + \
-"if set to negative, use same as training. range: (0.0, 1.0].")
+parser.add_argument(
+'--moe-eval-capacity-token-fraction', type=float, default=0.25,
+help=(
+"Fraction of tokens as capacity during validation"
+"if set to negative, use same as training. range: (0.0, 1.0]."
+)
+)
parser.add_argument('--moe-normalize-expert-grad', type=str, default='world_size',
help="Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'")
parser.add_argument('--use-moe-pad-mask', default=False, action='store_true',
help="Don't route padding tokens to any expert")
parser.add_argument('--use-xmoe', default=False, action='store_true',
@@ -207,7 +202,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
args.ddp_rank = distributed_utils.get_data_parallel_rank()
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
@@ -279,18 +274,18 @@ class TranslationModel(FairseqEncoderDecoderModel):
encoder = cls.build_encoder(
args,
encoder_embed_tokens,
encoder_embed_positions,
src_dict,
)
decoder = cls.build_decoder(
args,
decoder_embed_tokens,
decoder_embed_positions,
output_projection,
tgt_dict,
)
if not args.share_all_embeddings:
min_params_to_wrap = getattr(
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
@@ -317,9 +312,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
config.override(args)
return MTEncoder(
config,
embed_tokens,
embed_positions,
is_encoder_decoder=True,
dictionary=dictionary,
)
@@ -330,8 +325,8 @@ class TranslationModel(FairseqEncoderDecoderModel):
config.override(args)
return MTDecoder(
config,
embed_tokens,
embed_positions,
output_projection,
is_encoder_decoder=True,
@@ -348,7 +343,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder(
@@ -395,6 +390,7 @@ class MTEncoder(Encoder, FairseqEncoder):
def max_positions(self):
return self.embed_positions.max_positions
@register_model_architecture("mt", "mt_base")
def base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)

View File

@@ -32,4 +32,4 @@ for file in os.listdir(tasks_dir):
# fmt: on
group_args = parser.add_argument_group("Additional command-line arguments")
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,14 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import math
-import re
-import sys
-import time
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
class BaseBatchGen(CheckpointableIterator):
"""
This is a base class for batch generators that use infinibatch
@@ -26,7 +23,7 @@ class BaseBatchGen(CheckpointableIterator):
Build infinibatch iterator and assign to self._iter
"""
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
@@ -47,16 +44,16 @@ class BaseBatchGen(CheckpointableIterator):
def __next__(self):
return next(self._iter)
def setstate(self, value):
self._iter.setstate(value)
def getstate(self):
return self._iter.getstate()
def close(self):
self._iter.close()
def __len__(self) -> int:
return 819200000
@@ -78,4 +75,4 @@ class BaseBatchGen(CheckpointableIterator):
@property
def first_batch(self):
return "DUMMY"

View File

@@ -1,13 +1,8 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import glob
import os
-import torch
import numpy as np
-import time
-import json
-import random
import itertools
import copy
@@ -55,15 +50,15 @@ class MLMLoader(BaseBatchGen):
self.batch_read_ahead = args.batch_read_ahead
self._build_iter()
def _build_iter(self):
tokenized_lines = self._multilingual_tokenize()
self.padded_batches = self._batchify(tokenized_lines)
prefetch_batches = iterators.PrefetchIterator(
self.padded_batches,
buffer_size=10000,
buffer_in_main_process=True,
log_empty_buffer_warning=True and self.shard_id == 0,
)
@@ -85,14 +80,14 @@ class MLMLoader(BaseBatchGen):
weights.append(float(data['weight']))
else:
weights.append(int(data['count']))
if len(multilingual_iters) == 1:
return multilingual_iters[0]
sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
return tokenized_lines
def _tokenize(self, data):
@@ -109,7 +104,7 @@ class MLMLoader(BaseBatchGen):
dataset = list(
zip(
data['source'],
itertools.repeat(data['source_lang']),
)
)
@@ -117,27 +112,26 @@ class MLMLoader(BaseBatchGen):
chunk_files = \
iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
else:
chunk_files = \
iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
return tokenized_lines
def _batchify(self, lines):
if self.max_sentences is not None:
if self.batch_read_ahead > 0:
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
@@ -145,14 +139,15 @@ class MLMLoader(BaseBatchGen):
else:
def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
-batch_size = self.max_tokens // max(lengths) // self.required_batch_size_multiple * self.required_batch_size_multiple
+batch_size = self.max_tokens // max(lengths)
+batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator(
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
)
@@ -166,15 +161,15 @@ class MLMLoader(BaseBatchGen):
s2s_target_max_length = max([len(x[3]) for x in batch])
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
@@ -182,7 +177,7 @@ class MLMLoader(BaseBatchGen):
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
ret_batch = {
'net_input': {
'src_tokens': mlm_source_ids.astype(np.int64),
@@ -199,16 +194,16 @@ class MLMLoader(BaseBatchGen):
)
return padded_batches
def _prepare(self, _random, doc):
nonmasked_tokens, masked_tokens = self._mask_lm(_random, doc)
nonnoise_spans, noise_spans = self._span_corruption(_random, doc)
return nonmasked_tokens, masked_tokens, nonnoise_spans, noise_spans
def _mask_lm(self, _random, doc):
def mask_tokens():
-return f"<mask>"
+return "<mask>"
length = len(doc)
mask_tokens_num = int(length * self.args.mask_prob)
mask_tokens_num = min(max(mask_tokens_num, 1), length - 1)
@@ -222,11 +217,11 @@ class MLMLoader(BaseBatchGen):
# masked_tokens.append(nonmasked_tokens[position])
masked_tokens[position] = nonmasked_tokens[position]
nonmasked_tokens[position] = self.dictionary.indices[mask_tokens()]
return nonmasked_tokens, masked_tokens
def _span_corruption(self, _random, doc):
def mask_tokens(i):
return f"<mask_{i}>"
@@ -244,7 +239,7 @@ class MLMLoader(BaseBatchGen):
_random.shuffle(possible_split_positions)
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
possible_insert_positions = list(range(nonnoise_tokens_num))
_random.shuffle(possible_insert_positions)
noise_insert_positions = sorted(possible_insert_positions[:noise_spans_num])
@@ -260,14 +255,14 @@ class MLMLoader(BaseBatchGen):
noise_spans.append(doc[start_pos:end_pos])
else:
noise_spans.append([mask_id] + doc[start_pos:end_pos])
if getattr(self.args, "remove_source_sentinel", False):
nonnoise_spans.extend(doc[last_end:start_pos])
else:
nonnoise_spans.extend(doc[last_end:start_pos] + [mask_id])
last_end = end_pos
nonnoise_spans.extend(doc[last_end:])
noise_spans = sum(noise_spans, [])
@@ -276,10 +271,10 @@ class MLMLoader(BaseBatchGen):
def _read_from_files(self, source_file, source_lang):
# data = []
file_path = os.path.join(self.data_dir, source_file)
if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True)
return iter([])  # skip bad file
with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n')
@@ -292,7 +287,7 @@ class MLMLoader(BaseBatchGen):
yield doc
doc = [self.dictionary.bos()]
continue
tokenized_line = self.tokenizer.EncodeAsPieces(line)
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
@@ -308,4 +303,4 @@ class MLMLoader(BaseBatchGen):
# data.append(doc)
yield doc
# return data
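
The reflowed dynamic_batch_size hunk above first divides the token budget by the longest sample and then floors the result to a multiple of required_batch_size_multiple; a quick worked example with made-up numbers:

    # Worked example of the batch-size arithmetic reflowed above (all values are made up).
    max_tokens = 4096
    lengths = [37, 52, 41]                  # token counts of the sampled lines
    required_batch_size_multiple = 8

    batch_size = max_tokens // max(lengths)  # 4096 // 52 = 78 sentences fit the budget
    batch_size = batch_size // required_batch_size_multiple * required_batch_size_multiple  # floor to 72
    print(max(1, batch_size))  # 72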

View File

@@ -1,14 +1,13 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import os
-import gzip
import numpy as np
from random import Random
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterable, Optional
import collections
from infinibatch import iterators
def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
@@ -34,6 +33,7 @@ def apply_to_sample(f, sample):
return _apply(sample)
class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def __init__(self, iterable: Iterable):
self._input_iterable = iterable
@@ -44,13 +44,16 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
-self._num_items_yielded = iterators._advance_iterator(self._iterator, checkpoint['num_items_yielded']) if checkpoint is not None else 0
+self._num_items_yielded = iterators._advance_iterator(
+self._iterator,
+checkpoint['num_items_yielded']
+) if checkpoint is not None else 0
def __next__(self):
item = next(self._iterator)
self._num_items_yielded += 1
return item
def close(self):
pass
@@ -61,17 +64,17 @@ class WeightIterator(object):
self.seed = seed
self.control_index = list(range(len(weights)))
self.setstate(None)
def __iter__(self):
return self
def getstate(self):
return {"random_state": self._random_state}
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None  # this will trigger the lazy initialization in self.__next__
def __next__(self):
if self._random is None:
self._random = Random(self.seed)
@@ -80,6 +83,6 @@ class WeightIterator(object):
idx = self._random.choices(self.control_index, self.weights)[0]
self._random_state = self._random.getstate()
return idx
def close(self):
pass

View File

@@ -10,16 +10,13 @@ import logging
import os
from argparse import Namespace
import json
-from omegaconf import MISSING, II, OmegaConf
+from omegaconf import MISSING, II
-from typing import Any
-import numpy as np
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
-from fairseq.data.encoders.sentencepiece_bpe import SentencepieceBPE
import sentencepiece as spm
logger = logging.getLogger(__name__)
@@ -27,6 +24,7 @@ logger = logging.getLogger(__name__)
SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"])
SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"])
@dataclass
class PretrainingConfig(FairseqDataclass):
data: str = field(
@@ -163,11 +161,11 @@ class PLMTask(FairseqTask):
'shuffle': True if split == 'train' else False,
}
self.datasets[split] = Namespace(**self.datasets[split])
def dataset(self, split):
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
return self.datasets[split]
def get_batch_iterator(
@@ -207,4 +205,4 @@ class PLMTask(FairseqTask):
@property
def target_dictionary(self):
return self.dictionary

View File

@@ -1,11 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
+# flake8: noqa
import models
import tasks
from fairseq_cli.train import cli_main
if __name__ == "__main__":
cli_main()

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -7,6 +7,7 @@ from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist
import math
@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def grad_exists(p):
@@ -75,4 +76,4 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
for g in grads + expert_grads + sharded_grads + base_expert_grads:
g.mul_(clip_coef)
return total_norm

View File

@@ -25,4 +25,4 @@ setup(
classifiers=[
'Programming Language :: Python :: 3',
],
)

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = DecoderConfig(**args)

View File

@@ -23,6 +23,7 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_encoder(args):
config = EncoderConfig(**args)

View File

@@ -25,13 +25,14 @@ testcases = [
{"fsdp": True}
]
@pytest.mark.parametrize("args", testcases)
def test_decoder(args):
config = EncoderDecoderConfig(**args)
model = EncoderDecoder(
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
)
@@ -41,6 +42,6 @@ def test_decoder(args):
model(
src_tokens=src_tokens,
prev_output_tokens=prev_output_tokens,
features_only=True,
)

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

View File

@@ -355,7 +355,8 @@ def top2gating(
if has_tutel:
locations1_s = torch.sum(locations1 * mask1_, dim=1)
locations2_s = torch.sum(locations2 * mask2_, dim=1)
-return l_aux, metadata, capacity, num_experts, [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
+return l_aux, metadata, capacity, num_experts, \
+[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
View File

@@ -1,2 +1,2 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]