Code reformatting

pull/5/head
author shumingma
date   2022-11-26 09:01:02 +07:00
parent 1354614d44
commit 7eca1a531c
29 changed files with 775 additions and 557 deletions
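
This commit is a pure style pass: the hunks below normalize string literals to double quotes, rewrap long calls and dataclass field(...) metadata onto shorter lines, tighten simple power expressions (for example, embed_dim ** -0.5 becomes embed_dim**-0.5), and regroup imports, with no behavioral changes. The layout matches what automated formatters produce. As a minimal sketch only - assuming the black and isort tools are installed, and with target paths (torchscale, examples, tests, setup.py) inferred from the files touched here rather than stated in the commit - such a pass could be driven like this:

    # reformat.py - hypothetical helper, not part of this commit.
    # Assumes black and isort are installed (pip install black isort);
    # the target paths below are inferred from this diff, not from the repo.
    import subprocess

    TARGETS = ["torchscale", "examples", "tests", "setup.py"]

    def reformat(targets=TARGETS):
        """Reorder imports with isort, then rewrap and requote code with black."""
        subprocess.run(["isort", *targets], check=True)   # import grouping and ordering
        subprocess.run(["black", *targets], check=True)   # quotes, wrapping, trailing commas

    if __name__ == "__main__":
        reformat()

Running those two commands from the repository root would regenerate formatting of the kind shown in the hunks below.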

@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.generate import cli_main
if __name__ == "__main__":

@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.interactive import cli_main
if __name__ == "__main__":

@ -2,24 +2,24 @@
# Licensed under The MIT License [see LICENSE for details]
import logging
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from fairseq import utils
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
)
from fairseq.modules import PositionalEmbedding
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.models.squad import SQuADHead
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from omegaconf import II
from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig
from apex.normalization import FusedLayerNorm as LayerNorm
from .machine_translation import MTEncoder as Encoder
DEFAULT_MAX_SOURCE_POSITIONS = 1024
@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
},
)
max_source_positions: int = field(
default=1024, metadata={"help": "max source positions"}
@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
default="relu", metadata={"help": "activation function to use for pooler layer"}
)
pooler_dropout: float = field(
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
default=0.0,
metadata={"help": "dropout probability in the masked_lm pooler layers"},
)
# options from other parts of the config
# add_bos_token: bool = II("task.add_bos_token")
# tokens_per_sample: int = II("task.tokens_per_sample")
tpu: bool = II("common.tpu")
rel_pos_buckets: int = field(
default=0, metadata={"help": ""}
)
max_rel_pos: int = field(
default=0, metadata={"help": ""}
)
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
max_rel_pos: int = field(default=0, metadata={"help": ""})
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
},
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):
@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
def __init__(self, args, encoder):
super().__init__()
self.args = args
@ -240,7 +227,11 @@ class BertModel(BaseFairseqModel):
)
lm_head = cls.build_lm_head(
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
args,
args.encoder_embed_dim,
len(task.dictionary),
args.activation_fn,
weight=embed_tokens.weight,
)
config = EncoderConfig()
@ -269,7 +260,9 @@ class BertModel(BaseFairseqModel):
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
@ -277,7 +270,7 @@ class BertModel(BaseFairseqModel):
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
'and inner_dim {} (prev: {})'.format(
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@ -295,42 +288,51 @@ class BertModel(BaseFairseqModel):
)
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
prefix = name + "." if name != "" else ""
# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)
# Handle new classification heads present in the state dict.
current_head_names = (
[] if not hasattr(self, 'classification_heads')
[]
if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
if not k.startswith(prefix + "classification_heads."):
continue
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
head_name = k[len(prefix + "classification_heads.") :].split(".")[0] # noqa: E203
num_classes = state_dict[
prefix + "classification_heads." + head_name + ".out_proj.weight"
].size(0)
inner_dim = state_dict[
prefix + "classification_heads." + head_name + ".dense.weight"
].size(0)
if getattr(self.args, 'load_checkpoint_heads', False):
if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
'deleting classification head ({}) from checkpoint '
'not present in current model: {}'.format(head_name, k)
"deleting classification head ({}) from checkpoint "
"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes != self.classification_heads[head_name].out_proj.out_features
or inner_dim != self.classification_heads[head_name].dense.out_features
num_classes
!= self.classification_heads[head_name].out_proj.out_features
or inner_dim
!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
'deleting classification head ({}) from checkpoint '
'with different dimensions than current model: {}'.format(head_name, k)
"deleting classification head ({}) from checkpoint "
"with different dimensions than current model: {}".format(
head_name, k
)
)
keys_to_delete.append(k)
for k in keys_to_delete:
@ -338,12 +340,12 @@ class BertModel(BaseFairseqModel):
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + 'classification_heads.' + k not in state_dict:
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
if prefix + "classification_heads." + k not in state_dict:
logger.info("Overwriting " + prefix + "classification_heads." + k)
state_dict[prefix + "classification_heads." + k] = v
def forward(
self,
@ -354,7 +356,9 @@ class BertModel(BaseFairseqModel):
masked_tokens=None,
**kwargs
):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
encoder_out = self.encoder(
src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)
@ -455,7 +459,7 @@ def base_unilm_architecture(args):
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
# Model training is not stable without this
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)

@ -9,10 +9,9 @@
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import utils
from fairseq import distributed_utils
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
@ -20,14 +19,13 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.config import DecoderConfig
from omegaconf import II
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@ -104,49 +102,34 @@ class LanguageConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
},
)
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@ -155,23 +138,29 @@ class LanguageConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
},
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
@ -205,7 +194,6 @@ class LanguageConfig(FairseqDataclass):
@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@ -245,19 +233,17 @@ class LanguageModel(FairseqLanguageModel):
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
if (
getattr(args, 'moe_freq', 0) > 0
and (
getattr(args, 'fp16', False)
and not getattr(args, 'memory_efficient_fp16', False)
and getattr(args, 'ddp_backend', None) != "fully_sharded"
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert args.fp16_no_flatten_grads, \
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
@ -281,7 +267,6 @@ class LanguageModel(FairseqLanguageModel):
class LMDecoder(Decoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict, List, Optional, Tuple
import torch
from fairseq import utils
from fairseq import distributed_utils, utils
from fairseq.distributed import utils as fsdp_wrap
from fairseq import distributed_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
@ -20,12 +20,13 @@ from fairseq.models import (
)
from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
from torch import Tensor
from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
from torch import Tensor
import logging
logger = logging.getLogger(__name__)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
@ -35,7 +36,6 @@ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
@ -269,7 +269,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
args.decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
encoder = cls.build_encoder(
@ -320,7 +320,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
)
@classmethod
def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
def build_decoder(
cls, args, embed_tokens, embed_positions, output_projection, dictionary
):
config = DecoderConfig()
config.override(args)
@ -342,10 +344,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
features_only: bool = False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
@ -365,15 +364,20 @@ class TranslationModel(FairseqEncoderDecoderModel):
class MTEncoder(Encoder, FairseqEncoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)
return super().forward(
src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
0, new_order
)
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
0, new_order
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:

@ -3,6 +3,7 @@
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
@ -25,7 +26,6 @@ class BaseBatchGen(CheckpointableIterator):
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
return torch.tensor(x)

@ -1,32 +1,32 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import os
import numpy as np
import itertools
import copy
import itertools
import os
import numpy as np
from infinibatch import iterators
from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator
class MLMLoader(BaseBatchGen):
def __init__(
self,
args,
dataset,
dictionary,
tokenizer,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
self,
args,
dataset,
dictionary,
tokenizer,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
):
super().__init__()
self.args = args
@ -62,9 +62,7 @@ class MLMLoader(BaseBatchGen):
log_empty_buffer_warning=True and self.shard_id == 0,
)
prefetch_batches = iterators.MapIterator(
prefetch_batches, self._move_to_tensor
)
prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)
self._iter = prefetch_batches
@ -73,25 +71,25 @@ class MLMLoader(BaseBatchGen):
weights = []
for data in self.data:
multilingual_iters.append(
self._tokenize(data)
)
if 'weight' in data:
weights.append(float(data['weight']))
multilingual_iters.append(self._tokenize(data))
if "weight" in data:
weights.append(float(data["weight"]))
else:
weights.append(int(data['count']))
weights.append(int(data["count"]))
if len(multilingual_iters) == 1:
return multilingual_iters[0]
sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
tokenized_lines = iterators.MultiplexIterator(
control_iterator, multilingual_iters
)
return tokenized_lines
def _tokenize(self, data):
'''
"""
data:
{
'source': list[Path],
@ -100,33 +98,35 @@ class MLMLoader(BaseBatchGen):
'weight': float,
'name': str,
}
'''
"""
dataset = list(
zip(
data['source'],
itertools.repeat(data['source_lang']),
)
zip(
data["source"],
itertools.repeat(data["source_lang"]),
)
)
if self.shuffle:
chunk_files = \
iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
chunk_files = iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
else:
chunk_files = \
iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
chunk_files = iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
tokenized_lines = iterators.SelectManyIterator(
chunk_files, lambda files: self._read_from_files(*files)
)
tokenized_lines = iterators.SamplingRandomMapIterator(
tokenized_lines, self._prepare, self.seed
)
return tokenized_lines
@ -134,22 +134,29 @@ class MLMLoader(BaseBatchGen):
if self.max_sentences is not None:
if self.batch_read_ahead > 0:
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
lines = iterators.BlockwiseShuffleIterator(
lines, self.batch_read_ahead, self.seed
)
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
else:
def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths)
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
batch_size = (
batch_size
// self.required_batch_size_multiple
* self.required_batch_size_multiple
)
return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator(
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
)
def collate(batch):
@ -160,38 +167,56 @@ class MLMLoader(BaseBatchGen):
s2s_source_max_length = max([len(x[2]) for x in batch])
s2s_target_max_length = max([len(x[3]) for x in batch])
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
mlm_source_ids = np.full(
shape=(batch_size, mlm_source_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
mlm_target_ids = np.full(
shape=(batch_size, mlm_target_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_source_ids = np.full(
shape=(batch_size, s2s_source_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_target_ids = np.full(
shape=(batch_size, s2s_target_max_length - 1),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_prev_input_ids = np.full(
shape=(batch_size, s2s_target_max_length - 1),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
for i, (
mlm_input_ids,
mlm_label_ids,
s2s_input_ids,
s2s_label_ids,
) in enumerate(batch):
mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]
ret_batch = {
'net_input': {
'src_tokens': mlm_source_ids.astype(np.int64),
"net_input": {
"src_tokens": mlm_source_ids.astype(np.int64),
},
'target': mlm_target_ids.astype(np.int64),
'nsentences': batch_size,
'ntokens': sum([len(x[0]) for x in batch]),
"target": mlm_target_ids.astype(np.int64),
"nsentences": batch_size,
"ntokens": sum([len(x[0]) for x in batch]),
}
return ret_batch
padded_batches = iterators.MapIterator(
batches, collate
)
padded_batches = iterators.MapIterator(batches, collate)
return padded_batches
@ -221,7 +246,6 @@ class MLMLoader(BaseBatchGen):
return nonmasked_tokens, masked_tokens
def _span_corruption(self, _random, doc):
def mask_tokens(i):
return f"<mask_{i}>"
@ -237,7 +261,9 @@ class MLMLoader(BaseBatchGen):
else:
possible_split_positions = list(range(1, noise_tokens_num))
_random.shuffle(possible_split_positions)
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
noise_split_positions = sorted(
possible_split_positions[: noise_spans_num - 1]
)
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
possible_insert_positions = list(range(nonnoise_tokens_num))
@ -248,7 +274,7 @@ class MLMLoader(BaseBatchGen):
last_end = 0
for i in range(noise_spans_num):
start_pos = noise_insert_positions[i] + noise_split_positions[i]
end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
mask_id = self.dictionary.indices[mask_tokens(i)]
if getattr(self.args, "remove_target_sentinel", False):
@ -273,23 +299,25 @@ class MLMLoader(BaseBatchGen):
file_path = os.path.join(self.data_dir, source_file)
if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True)
print("| file {} not exists".format(file_path), flush=True)
return iter([]) # skip bad file
with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n')
with open(file_path, "r", encoding="utf8") as f:
lines = f.read().strip().split("\n")
doc = [self.dictionary.bos()]
for line in lines:
if line == "":
if self.sample_break_mode == 'complete_doc':
if self.sample_break_mode == "complete_doc":
# data.append(doc)
yield doc
doc = [self.dictionary.bos()]
continue
tokenized_line = self.tokenizer.EncodeAsPieces(line)
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
tokenized_id = [
self.dictionary.index(token) for token in tokenized_line
] + [self.dictionary.eos_index]
if len(tokenized_id) > self.tokens_per_sample:
continue

@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import numpy as np
import collections
from random import Random
from typing import Dict, Iterable, Optional
import collections
import numpy as np
from infinibatch import iterators
@ -17,7 +18,9 @@ def apply_to_sample(f, sample):
return f(x)
elif isinstance(x, collections.OrderedDict):
# OrderedDict has attributes that needs to be preserved
od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
od = collections.OrderedDict(
(key, _apply(value)) for key, value in x.items()
)
od.__dict__ = x.__dict__
return od
elif isinstance(x, dict):
@ -40,14 +43,15 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
self.setstate(None)
def getstate(self) -> Dict:
return {'num_items_yielded': self._num_items_yielded}
return {"num_items_yielded": self._num_items_yielded}
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(
self._iterator,
checkpoint['num_items_yielded']
) if checkpoint is not None else 0
self._num_items_yielded = (
iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
if checkpoint is not None
else 0
)
def __next__(self):
item = next(self._iterator)
@ -73,7 +77,9 @@ class WeightIterator(object):
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None # this will trigger the lazy initialization in self.__next__
self._random = (
None # this will trigger the lazy initialization in self.__next__
)
def __next__(self):
if self._random is None:

@ -1,23 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import json
import logging
import os
from argparse import Namespace
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
import os
from argparse import Namespace
import json
from omegaconf import MISSING, II
import sentencepiece as spm
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING
from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
import sentencepiece as spm
logger = logging.getLogger(__name__)
@ -109,21 +111,16 @@ class PretrainingConfig(FairseqDataclass):
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
spm_model: str = field(
default="",
metadata={
"help": "sentencepice model to tokenize the data"
},
metadata={"help": "sentencepice model to tokenize the data"},
)
dict_file: str = field(
default="",
metadata={
"help": ""
},
metadata={"help": ""},
)
@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):
def __init__(self, cfg, dictionary, tokenizer):
super().__init__(cfg)
self.cfg = cfg
@ -156,9 +153,9 @@ class PLMTask(FairseqTask):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = {
'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
'data_dir': self.cfg.data,
'shuffle': True if split == 'train' else False,
"data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
"data_dir": self.cfg.data,
"shuffle": True if split == "train" else False,
}
self.datasets[split] = Namespace(**self.datasets[split])
@ -185,18 +182,18 @@ class PLMTask(FairseqTask):
disable_iterator_cache=False,
):
return MLMLoader(
self.cfg,
dataset,
self.dictionary,
self.tokenizer,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
self.cfg,
dataset,
self.dictionary,
self.tokenizer,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
)
@property

@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.train import cli_main
if __name__ == "__main__":

@ -1,17 +1,21 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch
import math
import warnings
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch
import torch.distributed as dist
import math
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def clip_grad_norm_(
params, max_norm, moe_expert_count, aggregate_norm_fn=None
) -> torch.Tensor:
def grad_exists(p):
return p is not None and getattr(p, "grad", None) is not None
if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
@ -59,7 +63,9 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
for split_grads in [expert_grads, sharded_grads]:
if len(split_grads) == 0:
continue
split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
split_norm = torch.norm(
torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
)
if dist.is_initialized():
split_norm.pow_(2)
dist.all_reduce(split_norm)

@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]
from io import open
from setuptools import find_packages, setup
setup(
@ -10,19 +11,15 @@ setup(
author="TorchScale Team",
author_email="Shuming.Ma@microsoft.com",
description="Transformers at any scale",
long_description=open("README.md", "r", encoding='utf-8').read(),
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="Transformers at any scale",
license="MIT",
url="https://github.com/msranlp/torchscale",
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=['apex',
'torch>=1.8',
'fairscale==0.4.0',
'timm==0.4.12'],
python_requires='>=3.8.0',
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
python_requires=">=3.8.0",
classifiers=[
'Programming Language :: Python :: 3',
"Programming Language :: Python :: 3",
],
)

@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
import torch
testcases = [
{},
@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]

@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
import torch
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
import torch
testcases = [
{},
@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_encoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]

@ -2,10 +2,11 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
import torch
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
import torch
from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
testcases = [
{},
@ -16,13 +17,18 @@ testcases = [
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
{"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
{
"deepnorm": True,
"subln": False,
"encoder_normalize_before": False,
"decoder_normalize_before": False,
},
{"bert_init": True},
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"share_all_embeddings": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]
@ -33,8 +39,12 @@ def test_decoder(args):
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(
config.max_source_positions, config.encoder_embed_dim
),
decoder_embed_positions=PositionalEmbedding(
config.max_target_positions, config.decoder_embed_dim
),
)
src_tokens = torch.ones(2, 20).long()

@ -1,6 +1,7 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
class EncoderConfig(object):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
@ -19,9 +20,13 @@ class EncoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@ -29,7 +34,9 @@ class EncoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
self.share_encoder_input_output_embed = kwargs.pop(
"share_encoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@ -78,9 +85,13 @@ class DecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@ -88,7 +99,9 @@ class DecoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@ -138,9 +151,13 @@ class EncoderDecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@ -149,7 +166,9 @@ class EncoderDecoderConfig(object):
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)

@ -2,22 +2,23 @@
# Licensed under The MIT License [see LICENSE for details]
import math
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
class DecoderLayer(nn.Module):
def __init__(
self,
args,
@ -31,7 +32,9 @@ class DecoderLayer(nn.Module):
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
@ -206,7 +209,6 @@ class DecoderLayer(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
args,
@ -228,7 +230,11 @@ class Decoder(nn.Module):
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions
if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
if (
output_projection is None
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
@ -286,7 +292,12 @@ class Decoder(nn.Module):
else:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if args.subln:
@ -295,9 +306,14 @@ class Decoder(nn.Module):
else:
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
for name, p in self.named_parameters():
if 'encoder_attn' in name:
if "encoder_attn" in name:
continue
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
def build_output_projection(
@ -316,16 +332,12 @@ class Decoder(nn.Module):
args.decoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
return output_projection
def build_decoder_layer(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
):
layer = DecoderLayer(
args,
@ -347,7 +359,9 @@ class Decoder(nn.Module):
):
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(tokens, incremental_state=incremental_state)
positions = self.embed_positions(
tokens, incremental_state=incremental_state
)
if incremental_state is not None:
tokens = tokens[:, -1:]
@ -381,7 +395,9 @@ class Decoder(nn.Module):
**kwargs
):
# embed tokens and positions
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
x = x.transpose(0, 1)
# relative postion
@ -389,9 +405,7 @@ class Decoder(nn.Module):
slen = prev_output_tokens.size(1)
if self.self_attn_relative_position is not None:
self_attn_rel_pos_bias = self.self_attn_relative_position(
batch_size=x.size(1),
qlen=slen,
klen=slen
batch_size=x.size(1), qlen=slen, klen=slen
)
if incremental_state is not None:
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
@ -416,7 +430,11 @@ class Decoder(nn.Module):
for idx, layer in enumerate(self.layers):
if incremental_state is None:
self_attn_mask = torch.triu(
torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
torch.zeros([x.size(0), x.size(0)])
.float()
.fill_(float("-inf"))
.type_as(x),
1,
)
else:
self_attn_mask = None
@ -426,7 +444,9 @@ class Decoder(nn.Module):
x, layer_attn, _, l_aux_i = layer(
x,
encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"]
if encoder_out is not None
else None,
incremental_state[idx] if incremental_state is not None else None,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
@ -444,7 +464,11 @@ class Decoder(nn.Module):
if not features_only:
x = self.output_layer(x)
return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}
return x, {
"inner_states": inner_states,
"l_aux": l_aux,
"attn": [layer_attn.mean(dim=0)],
}
def output_layer(self, features):
return self.output_projection(features)

@ -2,30 +2,25 @@
# Licensed under The MIT License [see LICENSE for details]
import math
import numpy as np
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
class EncoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
):
def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
super().__init__()
self.args = args
self.embed_dim = args.encoder_embed_dim
@ -34,7 +29,9 @@ class EncoderLayer(nn.Module):
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
@ -49,7 +46,7 @@ class EncoderLayer(nn.Module):
self.build_ffn(
self.embed_dim,
self.args,
)
),
)
else:
assert not self.args.multiway
@ -77,7 +74,12 @@ class EncoderLayer(nn.Module):
if args.deepnorm:
if is_encoder_decoder:
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
self.alpha = (
math.pow(
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
)
* 0.81
)
else:
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
else:
@ -107,13 +109,7 @@ class EncoderLayer(nn.Module):
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
encoder_padding_mask,
attn_mask=None,
rel_pos=None
):
def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
@ -158,7 +154,6 @@ class EncoderLayer(nn.Module):
class Encoder(nn.Module):
def __init__(
self,
args,
@ -179,13 +174,20 @@ class Encoder(nn.Module):
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions
if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
if (
output_projection is None
and not is_encoder_decoder
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
self.layernorm_embedding = MultiwayWrapper(
args, LayerNorm(embed_dim), dim=1
)
else:
self.layernorm_embedding = None
@ -199,7 +201,7 @@ class Encoder(nn.Module):
args,
depth=i,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder
is_encoder_decoder=is_encoder_decoder,
)
)
self.num_layers = len(self.layers)
@ -223,20 +225,39 @@ class Encoder(nn.Module):
if args.deepnorm:
if is_encoder_decoder:
init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
init_scale = (
math.pow(
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
)
/ 1.15
)
else:
init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if args.subln:
if is_encoder_decoder:
init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
init_scale = math.sqrt(
math.log(3 * args.decoder_layers)
* math.log(2 * args.encoder_layers)
/ 3
)
else:
init_scale = math.sqrt(math.log(args.encoder_layers * 2))
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
def build_output_projection(
@ -244,7 +265,7 @@ class Encoder(nn.Module):
args,
):
if args.share_encoder_input_output_embed:
assert args.encoder_embedding_type == 'language'
assert args.encoder_embedding_type == "language"
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
@ -256,22 +277,18 @@ class Encoder(nn.Module):
args.encoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
)
return output_projection
def build_encoder_layer(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
):
layer = EncoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder
is_encoder_decoder=is_encoder_decoder,
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
@ -312,13 +329,12 @@ class Encoder(nn.Module):
if encoder_padding_mask is None:
if src_tokens is not None:
encoder_padding_mask = torch.zeros_like(
src_tokens,
device=src_tokens.device
src_tokens, device=src_tokens.device
).bool()
else:
encoder_padding_mask = torch.zeros(
[token_embeddings.size(0), token_embeddings.size(1)],
device=token_embeddings.device
device=token_embeddings.device,
).bool()
if multiway_split_position is not None:
@ -338,16 +354,13 @@ class Encoder(nn.Module):
rel_pos_bias = None
if self.relative_position is not None:
rel_pos_bias = self.relative_position(
batch_size=x.size(1),
qlen=x.size(0),
klen=x.size(0)
batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
)
l_aux = []
for layer in self.layers:
x, l_aux_i = layer(
x, encoder_padding_mask=encoder_padding_mask,
rel_pos=rel_pos_bias
x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
)
if return_all_hiddens:
assert encoder_states is not None

@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.encoder import Encoder
class EncoderDecoder(nn.Module):
def __init__(
self,
args,
@ -51,10 +51,7 @@ class EncoderDecoder(nn.Module):
features_only=False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,

@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.multiway_network import MultiwayNetwork
def init_bert_params(module):
def normal_(data):
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

@ -1,13 +1,13 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
from timm.models.layers import drop_path
import torch.nn as nn
from timm.models.layers import drop_path
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
@ -16,4 +16,4 @@ class DropPath(nn.Module):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self):
return 'p={}'.format(self.drop_prob)
return "p={}".format(self.drop_prob)

@ -7,22 +7,12 @@ import torch.nn.functional as F
class VisionLanguageEmbedding(nn.Module):
def __init__(
self,
text_embed,
vision_embed
):
def __init__(self, text_embed, vision_embed):
super().__init__()
self.text_embed = text_embed
self.vision_embed = vision_embed
def forward(
self,
textual_tokens,
visual_tokens,
**kwargs
):
def forward(self, textual_tokens, visual_tokens, **kwargs):
if textual_tokens is None:
return self.vision_embed(visual_tokens)
@ -36,8 +26,8 @@ class VisionLanguageEmbedding(nn.Module):
class VisionEmbedding(nn.Module):
""" Image to Patch Embedding
"""
"""Image to Patch Embedding"""
def __init__(
self,
img_size=224,
@ -45,7 +35,7 @@ class VisionEmbedding(nn.Module):
in_chans=3,
embed_dim=768,
contain_mask_token=False,
prepend_cls_token=False
prepend_cls_token=False,
):
super().__init__()
img_size = (img_size, img_size)
@ -56,7 +46,9 @@ class VisionEmbedding(nn.Module):
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if contain_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@ -68,15 +60,11 @@ class VisionEmbedding(nn.Module):
else:
self.cls_token = None
def forward(
self,
x,
masked_position=None,
**kwargs
):
def forward(self, x, masked_position=None, **kwargs):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
batch_size, seq_len, _ = x.size()
@ -88,21 +76,21 @@ class VisionEmbedding(nn.Module):
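# Blend: positions where w == 1 take the learnable mask token, the rest keep their patch embedding.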
x = x * (1 - w) + mask_token * w
if self.cls_token is not None:
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
return x
class TextEmbedding(nn.Embedding):
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
self._fill_padding_idx_with_zero()
class PositionalEmbedding(nn.Embedding):
def forward(
self,
x,
@ -111,7 +99,9 @@ class PositionalEmbedding(nn.Embedding):
):
if positions is None:
# being consistent with Fairseq, which starts from 2.
positions = torch.arange(2, x.size(1)+2, device=x.device).long().unsqueeze(0)
positions = (
torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
)
return F.embedding(
positions,
self.weight,

@ -35,13 +35,19 @@ class set_torch_seed(object):
def make_experts(args, embed_dim, expert_ffn_dim):
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
world_size = (
1
if not torch.distributed.is_initialized()
else torch.distributed.get_world_size()
)
expert_list = []
ddp_rank = args.ddp_rank
start_seed = torch.randint(1000000, (1,)).item()
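# The seed offsets below depend on the DDP rank: with at least as many experts as ranks each
# local expert gets its own seed, otherwise ranks hosting the same expert share one seed offset.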
# at least as many experts as GPUs
if args.moe_expert_count >= world_size:
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
assert (
args.moe_expert_count % world_size == 0
), f"{args.moe_expert_count}, {world_size}"
local_moe_expert_count = args.moe_expert_count // world_size
for i in range(local_moe_expert_count):
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
@ -52,11 +58,13 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
args.subln,
)
)
else:
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
assert (
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
expert_list.append(
@ -66,7 +74,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
args.subln,
)
)
experts = nn.ModuleList(expert_list)
@ -83,7 +91,6 @@ def get_activation_fn(activation):
class FeedForwardNetwork(nn.Module):
def __init__(
self,
embed_dim,
@ -91,12 +98,14 @@ class FeedForwardNetwork(nn.Module):
activation_fn,
dropout,
activation_dropout,
subln=False
subln=False,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
self.activation_dropout_module = torch.nn.Dropout(
activation_dropout, inplace=True
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)

@ -2,15 +2,16 @@
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
from torch import nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from torch import nn
from .multiway_network import MultiwayWrapper
class MultiheadAttention(nn.Module):
def __init__(
self,
args,
@ -25,7 +26,7 @@ class MultiheadAttention(nn.Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
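# 1/sqrt(head_dim) scaling factor for scaled dot-product attention.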
self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention
@ -34,8 +35,14 @@ class MultiheadAttention(nn.Module):
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
self.out_proj = MultiwayWrapper(
args, nn.Linear(embed_dim, embed_dim, bias=True)
)
self.inner_attn_ln = (
MultiwayWrapper(args, LayerNorm(self.embed_dim))
if subln and self.self_attention
else None
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
def reset_parameters(self):
@ -76,12 +83,20 @@ class MultiheadAttention(nn.Module):
if incremental_state is not None:
if "prev_key" in incremental_state:
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
prev_key = incremental_state["prev_key"].view(
bsz * self.num_heads, -1, self.head_dim
)
prev_value = incremental_state["prev_value"].view(
bsz * self.num_heads, -1, self.head_dim
)
k = torch.cat([prev_key, k], dim=1)
v = torch.cat([prev_value, v], dim=1)
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
incremental_state["prev_key"] = k.view(
bsz, self.num_heads, -1, self.head_dim
)
incremental_state["prev_value"] = v.view(
bsz, self.num_heads, -1, self.head_dim
)
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
@ -103,7 +118,9 @@ class MultiheadAttention(nn.Module):
rel_pos = rel_pos.view(attn_weights.size())
attn_weights = attn_weights + rel_pos
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
)
attn_probs = self.dropout_module(attn_weights)
attn = torch.bmm(attn_probs, v)

@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]
import copy
import torch
import torch.nn as nn
@ -13,16 +14,14 @@ def MultiwayWrapper(args, module, dim=0):
def set_split_position(position):
def apply_fn(module):
if hasattr(module, 'split_position'):
if hasattr(module, "split_position"):
module.split_position = position
return apply_fn
class MultiwayNetwork(nn.Module):
def __init__(self, module, dim=0):
super().__init__()
self.dim = dim
@ -36,7 +35,11 @@ class MultiwayNetwork(nn.Module):
return self.A(x, **kwargs)
if self.split_position == 0:
return self.B(x, **kwargs)
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim)
x1, x2 = torch.split(
x,
[self.split_position, x.size(self.dim) - self.split_position],
dim=self.dim,
)
# x1, x2 = x[:self.split_position], x[self.split_position:]
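# The first split_position slices are routed through module A, the remainder through module B,
# and the two outputs are concatenated back along the same dimension.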
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
return torch.cat([y1, y2], dim=self.dim)

@ -2,17 +2,14 @@
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import torch.nn as nn
class RelativePositionBias(nn.Module):
def __init__(
self,
bidirectional=True,
num_buckets=32,
max_distance=128,
n_heads=12
self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
):
super().__init__()
self.bidirectional = bidirectional
@ -23,10 +20,7 @@ class RelativePositionBias(nn.Module):
@staticmethod
def _relative_position_bucket(
relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
ret = 0
n = -relative_position
@ -41,24 +35,28 @@ class RelativePositionBias(nn.Module):
is_small = n < max_exact
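# Distances beyond max_exact fall into logarithmically spaced buckets, clamped to num_buckets - 1.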
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large)
return ret
def compute_bias(
self,
qlen,
klen,
step=None
):
def compute_bias(self, qlen, klen, step=None):
step = 0 if step is None else step
context_position = torch.arange(step, step + qlen, dtype=torch.long,
device=self.relative_attention_bias.weight.device)[:, None]
memory_position = torch.arange(klen, dtype=torch.long,
device=self.relative_attention_bias.weight.device)[None, :]
context_position = torch.arange(
step,
step + qlen,
dtype=torch.long,
device=self.relative_attention_bias.weight.device,
)[:, None]
memory_position = torch.arange(
klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
@ -67,16 +65,18 @@ class RelativePositionBias(nn.Module):
num_buckets=self.num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
values = self.relative_attention_bias(
rp_bucket
) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, qlen, klen)
return values
def forward(
self,
batch_size,
qlen,
klen,
step=None
):
def forward(self, batch_size, qlen, klen, step=None):
# shape (batch * num_heads, qlen, klen)
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
return (
self.compute_bias(qlen, klen, step)
.repeat(batch_size, 1, 1, 1)
.view(-1, qlen, klen)
)

@ -18,9 +18,9 @@ import torch.distributed as dist
from torch import Tensor
from torch.nn import Module, ModuleList
try:
from fairseq.modules.moe import MOELayer
has_fairseq = True
Base = MOELayer
except ModuleNotFoundError:
@ -81,8 +81,10 @@ def get_moe_group(moe_expert_count):
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)]
moe_groups = [
[i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
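# Group i contains every rank that holds a replica of expert i.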
get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
@ -105,11 +107,15 @@ def get_all2all_group(moe_expert_count):
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)]
all2all_groups = [
[i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
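# Each all-to-all group contains exactly one rank per expert index.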
get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]
get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx]
@ -133,12 +139,7 @@ class MOELayer(Base):
expert network
"""
def __init__(
self,
gate,
experts,
args
):
def __init__(self, gate, experts, args):
if has_fairseq:
super(Base, self).__init__()
else:
@ -163,9 +164,13 @@ class MOELayer(Base):
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported"
input = input[0]
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
assert (
len(input.shape) == 3
), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
if input_padding_mask is not None:
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
assert (
len(input_padding_mask.shape) == 2
), "input Tensor must have dimensions: (s)equence, (t)oken"
assert input_padding_mask.shape[0] == input.shape[0]
assert input_padding_mask.shape[1] == input.shape[1]
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
@ -174,81 +179,120 @@ class MOELayer(Base):
d_model = input.shape[2]
# Pad to expected batch size
input_shape = list(input.shape)
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
expected_bsz = (
getattr(self.args, "batch_size", 0)
if self.training
else getattr(self.args, "batch_size_valid", 0)
)
# This indicates that --batch-size or --max-sentences is not specified
if expected_bsz is None:
expected_bsz = 0
# Note: Padding is not necessary at generation time at present
# because all DDP workers process the same batch. Also, batch size at generation time
# can be different from that present in the checkpoint state
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
if (
not self.in_generation
and expected_bsz != 0
and input_shape[0] != expected_bsz
):
logger.warning(
f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
)
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
padded_input = torch.zeros(
(expected_bsz, input_shape[1], input_shape[2]),
dtype=input.dtype, layout=input.layout, device=input.device)
padded_input[:input_shape[0], :, :] = input
dtype=input.dtype,
layout=input.layout,
device=input.device,
)
padded_input[: input_shape[0], :, :] = input
input = padded_input
padded_input_padding_mask = torch.ones(
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
(
expected_bsz,
input_shape[1],
),
dtype=torch.bool,
device=input.device,
)
if input_padding_mask is not None:
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
else:
padded_input_padding_mask[:input_shape[0], :] = False
padded_input_padding_mask[: input_shape[0], :] = False
input_padding_mask = padded_input_padding_mask
# Reshape into S tokens by dropping sequence dimension.
reshaped_input = input.reshape(-1, d_model)
reshaped_input_shape = reshaped_input.shape
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None
reshaped_input_padding_mask = (
input_padding_mask.reshape(-1) if input_padding_mask is not None else None
)
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
# Pro of --max-tokens: more flexible for MT variable sequence lengths
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
if expected_bsz == 0:
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
expected_dim = reshaped_input_shape[0] * torch.ones(
(1,), dtype=torch.long, device=input.device
)
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
expected_dim = int(expected_dim.item())
padded_input = torch.zeros(
(expected_dim, reshaped_input_shape[1]),
dtype=input.dtype, layout=input.layout, device=input.device)
padded_input[:reshaped_input_shape[0], :] = reshaped_input
dtype=input.dtype,
layout=input.layout,
device=input.device,
)
padded_input[: reshaped_input_shape[0], :] = reshaped_input
reshaped_input = padded_input
padded_input_padding_mask = torch.ones(
(expected_dim,), dtype=torch.bool, device=padded_input.device
)
if reshaped_input_padding_mask is not None:
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
padded_input_padding_mask[
: reshaped_input_shape[0]
] = reshaped_input_padding_mask
else:
padded_input_padding_mask[:reshaped_input_shape[0]] = False
padded_input_padding_mask[: reshaped_input_shape[0]] = False
reshaped_input_padding_mask = padded_input_padding_mask
if has_tutel:
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
reshaped_input, reshaped_input_padding_mask
)
S, M = reshaped_input.size(0), reshaped_input.size(1)
if not hasattr(self, '_tutel_dispatcher'):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
if not hasattr(self, "_tutel_dispatcher"):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
E, C, M, dispatch_dtype=reshaped_input.dtype
)
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else:
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
reshaped_input, reshaped_input_padding_mask
)
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S
dispatch_mask = dispatch_mask.to(input.dtype).permute(
1, 2, 0
) # S,E,C -> E,C,S
E, C, S = dispatch_mask.size()
M = reshaped_input.size(1)
assert reshaped_input.size() == (S, M)
# einsum("sec,sm->ecm")
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M
dispatched_input = torch.mm(
dispatch_mask.view(E * C, S), reshaped_input
) # -> (E*C),M
if self.all2all_size > 1:
dispatched_input = self.all_to_all_wrapper(dispatched_input)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
dispatched_input = dispatched_input.reshape(
self.all2all_size, self.num_local_experts, -1, d_model
)
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
@ -259,18 +303,24 @@ class MOELayer(Base):
expert_output = self.all_to_all_wrapper(expert_output)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)
expert_output = expert_output.reshape(
self.all2all_size * self.num_local_experts, -1, d_model
)
if has_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
combined_output = self._tutel_dispatcher.decode(
expert_output.view(E * C, M)
)
else:
# einsum("sec,ecm->sm")
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))
combined_output = combine_weights.view(S, E * C).mm(
expert_output.view(E * C, M)
)
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
combined_output = combined_output[:reshaped_input_shape[0], :]
combined_output = combined_output[: reshaped_input_shape[0], :]
combined_output = combined_output.reshape(input.shape)
combined_output = combined_output[:input_shape[0], :, :]
combined_output = combined_output[: input_shape[0], :, :]
self.record_all_to_all_stats()
@ -280,7 +330,7 @@ class MOELayer(Base):
self.in_generation = True
def all_to_all_wrapper(self, input: Tensor):
dummy_a2a = getattr(self.args, 'dummy_a2a', False)
dummy_a2a = getattr(self.args, "dummy_a2a", False)
if dummy_a2a:
input = input.contiguous()
output = input.detach().clone()
@ -294,13 +344,13 @@ class MOELayer(Base):
output = _AllToAll.apply(self.all2all_group, input)
cuda_end.record()
cpu_end = time.time() * 1000
self.a2a_cpu_time_ms += (cpu_end - cpu_start)
self.a2a_cpu_time_ms += cpu_end - cpu_start
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
return output
def record_all_to_all_stats(self):
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
if record_a2a_perf_stats:
torch.cuda.synchronize()
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms

@ -13,14 +13,14 @@
# NOTE: This is a mirror of the code in
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
from typing import Callable, Dict, Tuple, Optional
import math
from typing import Callable, Dict, Optional, Tuple
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import Tensor
from .moe_layer import has_tutel, fused_cumsum_sub_one
from .moe_layer import fused_cumsum_sub_one, has_tutel
# use a fixed temperature to compute balance loss
TEMPERATURE_FOR_L_UAX = 0.07
@ -65,13 +65,22 @@ def top1gating(
indices1_s = torch.argmax(gates, dim=1)
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask
nonpadding = ~input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
# for logging (percent of tokens routed to each expert)
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
expert1_hist = (
100
* torch.histc(
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
)
/ num_tokens
)
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
expert1_hist = (
torch.sort(expert1_hist, dim=0, descending=True).values
+ torch.finfo(torch.float32).tiny
)
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
@ -91,7 +100,21 @@ def top1gating(
if has_tutel:
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]
return (
l_aux,
metadata,
capacity,
num_experts,
[
indices1_s,
],
[
locations1_s,
],
[
gates1_s,
],
)
# Remove locations outside capacity from mask
mask1 = mask1 * torch.lt(locations1, capacity)
@ -104,7 +127,8 @@ def top1gating(
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
combine1_sec = torch.bmm(
# einsum("se,sc->sec")
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
gates1.unsqueeze(-1),
locations1_sc.to(gates1.dtype).unsqueeze(1),
)
dispatch_mask = combine1_sec.bool()
if use_fp32:
@ -218,10 +242,10 @@ def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) ->
if unsqueeze_indices:
indices = indices.unsqueeze(-1)
assert indices.shape[-1] == 1, "last dimension of indices must have size 1"
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype)
output.scatter_(
len(output.shape) - 1, indices, 1
output = torch.zeros(
indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
)
output.scatter_(len(output.shape) - 1, indices, 1)
return output
@ -235,7 +259,7 @@ def top2gating(
logits: torch.Tensor,
input_mask: Optional[torch.Tensor] = None,
use_fp32=False,
second_expert_policy='sampling',
second_expert_policy="sampling",
normalize_gate_prob_before_dropping=False,
eval_mode=False,
moe_eval_capacity_token_fraction=0.25,
@ -260,7 +284,7 @@ def top2gating(
# Create a mask for the 1st expert per token
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
mask1 = one_hot(indices1_s, num_experts)
if second_expert_policy == 'sampling':
if second_expert_policy == "sampling":
# Create a mask for the 2nd expert per token using the Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@ -281,13 +305,13 @@ def top2gating(
gates1_s = gates1_s / denom_s
gates2_s = gates2_s / denom_s
if second_expert_policy == 'random':
if second_expert_policy == "random":
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
# Compute locations in capacity buffer
if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask
nonpadding = ~input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
@ -296,15 +320,22 @@ def top2gating(
importance_scores = -1 * gates.max(dim=1)[0]
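# Batch Prioritized Routing: tokens with the largest top-1 gate probability claim capacity slots first.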
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
importance_sorted_locations1 = sorted_cumsum1[
importance_scores.argsort(dim=0).argsort(dim=0)
]
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
importance_sorted_locations2 = sorted_cumsum2[
importance_scores.argsort(dim=0).argsort(dim=0)
]
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
locations1, locations2 = (
importance_sorted_locations1,
importance_sorted_locations2,
)
else:
locations1 = fused_cumsum_sub_one(mask1)
locations2 = fused_cumsum_sub_one(mask2)
@ -318,8 +349,12 @@ def top2gating(
l_aux = l_aux * num_experts * num_experts
# for logging purposes
metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
metadata["overflow_expert1"] = (
100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
)
metadata["overflow_expert2"] = (
100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
)
# Remove locations outside capacity from mask
mask1_, mask2_ = mask1, mask2
@ -327,13 +362,31 @@ def top2gating(
mask2 = mask2 * torch.lt(locations2, capacity)
# for logging (percent of tokens routed to each expert)
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
expert1_hist = (
100
* torch.histc(
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
)
/ num_tokens
)
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
expert1_hist = (
torch.sort(expert1_hist, dim=0, descending=True).values
+ torch.finfo(torch.float32).tiny
)
expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
expert2_hist = (
100
* torch.histc(
(indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
)
/ num_tokens
)
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
expert2_hist = (
torch.sort(expert2_hist, dim=0, descending=True).values
+ torch.finfo(torch.float32).tiny
)
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
@ -355,8 +408,15 @@ def top2gating(
if has_tutel:
locations1_s = torch.sum(locations1 * mask1_, dim=1)
locations2_s = torch.sum(locations2 * mask2_, dim=1)
return l_aux, metadata, capacity, num_experts, \
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
return (
l_aux,
metadata,
capacity,
num_experts,
[indices1_s, indices2_s],
[locations1_s, locations2_s],
[gates1_s, gates2_s],
)
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
@ -369,11 +429,13 @@ def top2gating(
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
combine1_sec = torch.bmm(
# einsum("se,sc->sec")
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
gates1.unsqueeze(-1),
locations1_sc.to(gates1.dtype).unsqueeze(1),
)
combine2_sec = torch.bmm(
# einsum("se,sc->sec")
gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
gates2.unsqueeze(-1),
locations2_sc.to(gates2.dtype).unsqueeze(1),
)
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()
@ -406,7 +468,7 @@ class Top2Gate(torch.nn.Module):
model_dim: int,
num_experts: int,
use_fp32=False,
second_expert_policy='sampling',
second_expert_policy="sampling",
normalize_gate_prob_before_dropping=False,
moe_eval_capacity_token_fraction=0.25,
batch_prioritized_routing=False,

@ -3,37 +3,35 @@
import torch
import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
from torchscale.component.embedding import (
PositionalEmbedding,
TextEmbedding,
VisionEmbedding,
)
from torchscale.component.multiway_network import MultiwayWrapper
class BEiT3(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
assert args.multiway
assert args.vocab_size > 0
assert not args.share_encoder_input_output_embed
self.text_embed = TextEmbedding(
args.vocab_size,
args.encoder_embed_dim
)
self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
self.vision_embed = VisionEmbedding(
args.img_size,
args.patch_size,
args.in_chans,
args.encoder_embed_dim,
contain_mask_token=True,
prepend_cls_token=True
prepend_cls_token=True,
)
embed_positions = MultiwayWrapper(
args,
PositionalEmbedding(
args.max_source_positions,
args.encoder_embed_dim
),
PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
dim=1,
)
self.encoder = Encoder(
@ -71,7 +69,7 @@ class BEiT3(nn.Module):
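# The image token span contributes no padding; only text_padding_position marks padded positions.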
encoder_padding_mask = torch.cat(
[
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
text_padding_position
text_padding_position,
],
dim=1,
)