Code reformatting

shumingma 2022-11-26 09:01:02 -08:00
parent 1354614d44
commit 7eca1a531c
29 changed files with 781 additions and 563 deletions
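
The diffs below are consistent with a mechanical formatting pass: imports are regrouped and alphabetized, single quotes become double quotes, long calls are split one argument per line with trailing commas, and lines are wrapped at roughly 88 characters. The commit message does not name the tools, so attributing the pass to isort plus black is an assumption; a minimal sketch of how such a pass could be reproduced over the repository:

import pathlib

import black
import isort

# Hypothetical reproduction of the reformatting pass; isort and black are
# assumed here, not confirmed by the commit message.
for path in pathlib.Path(".").rglob("*.py"):
    source = path.read_text(encoding="utf-8")
    sorted_source = isort.code(source)  # regroup and alphabetize imports
    formatted = black.format_str(
        sorted_source, mode=black.Mode()
    )  # 88-column wrapping, double quotes, trailing commas
    if formatted != source:
        path.write_text(formatted, encoding="utf-8")

The occasional "# noqa: E203" added in the new code matches black's whitespace-around-slice style, which flake8's E203 check would otherwise flag.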

View File

@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.generate import cli_main
if __name__ == "__main__":

View File

@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.interactive import cli_main
if __name__ == "__main__":

View File

@@ -2,24 +2,24 @@
# Licensed under The MIT License [see LICENSE for details]
import logging
-from typing import Optional
from dataclasses import dataclass, field
+from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
-from fairseq import utils
-from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
-from fairseq.dataclass import ChoiceEnum, FairseqDataclass
-from fairseq.models.transformer import (
-DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
-)
-from fairseq.modules import PositionalEmbedding
-from fairseq.models.squad import SQuADHead
-from omegaconf import II
-from .machine_translation import MTEncoder as Encoder
-from torchscale.architecture.config import EncoderConfig
from apex.normalization import FusedLayerNorm as LayerNorm
+from fairseq import utils
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
+from fairseq.models.squad import SQuADHead
+from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
+from fairseq.modules import PositionalEmbedding
+from omegaconf import II
+from torchscale.architecture.config import EncoderConfig
+from .machine_translation import MTEncoder as Encoder
DEFAULT_MAX_SOURCE_POSITIONS = 1024
@@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
-}
+},
)
max_source_positions: int = field(
default=1024, metadata={"help": "max source positions"}
@@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
default="relu", metadata={"help": "activation function to use for pooler layer"}
)
pooler_dropout: float = field(
-default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
+default=0.0,
+metadata={"help": "dropout probability in the masked_lm pooler layers"},
)
# options from other parts of the config
# add_bos_token: bool = II("task.add_bos_token")
# tokens_per_sample: int = II("task.tokens_per_sample")
tpu: bool = II("common.tpu")
-rel_pos_buckets: int = field(
-default=0, metadata={"help": ""}
-)
-max_rel_pos: int = field(
-default=0, metadata={"help": ""}
-)
+rel_pos_buckets: int = field(default=0, metadata={"help": ""})
+max_rel_pos: int = field(default=0, metadata={"help": ""})
moe_freq: int = field(
default=0,
-metadata={
-"help": "Frequency at which we insert MoE Transformer layers"
-},
+metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
-default=0,
-metadata={
-"help": "Number of experts in each MoE Layer"
-}
+default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
-metadata={
-"help": "Use FP32 computations in MoE top2 gating function"
-}
+metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
-default='sampling',
-metadata={
-"help": "policy for second expert, options: all/sampling/random"
-}
+default="sampling",
+metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
-"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
-}
+"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
+},
)
moe_expert_ffn_dim: Optional[int] = field(
-default=None,
-metadata={
-"help": "MoE expert FFN dimension"
-}
+default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
-default=False,
-metadata={
-"help": "Use top1 gate instead of top2"
-}
+default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
-}
+},
)
moe_normalize_expert_grad: Optional[str] = field(
-default='world_size',
+default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
-}
+},
)
record_a2a_perf_stats: Optional[bool] = field(
-default=False, metadata={"help": "records all to all perf stats during distributed training"}
+default=False,
+metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
-default=False, metadata={
-"help": "By passes all to all during distributed training by returning the input buffer as output"}
+default=False,
+metadata={
+"help": "By passes all to all during distributed training by returning the input buffer as output"
+},
)
moe_batch_prioritized_routing: Optional[bool] = field(
-default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
+default=False,
+metadata={
+"help": "if true orders token by the gate prob before capacity dropping."
+},
)
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
@@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):
@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
def __init__(self, args, encoder):
super().__init__()
self.args = args
@@ -240,7 +227,11 @@ class BertModel(BaseFairseqModel):
)
lm_head = cls.build_lm_head(
-args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
+args,
+args.encoder_embed_dim,
+len(task.dictionary),
+args.activation_fn,
+weight=embed_tokens.weight,
)
config = EncoderConfig()
@@ -269,7 +260,9 @@ class BertModel(BaseFairseqModel):
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
-def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+def register_classification_head(
+self, name, num_classes=None, inner_dim=None, **kwargs
+):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
@@ -277,7 +270,7 @@ class BertModel(BaseFairseqModel):
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
-'and inner_dim {} (prev: {})'.format(
+"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@@ -295,42 +288,51 @@ class BertModel(BaseFairseqModel):
)
def upgrade_state_dict_named(self, state_dict, name):
-prefix = name + '.' if name != '' else ''
+prefix = name + "." if name != "" else ""
# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)
# Handle new classification heads present in the state dict.
current_head_names = (
-[] if not hasattr(self, 'classification_heads')
+[]
+if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
-if not k.startswith(prefix + 'classification_heads.'):
+if not k.startswith(prefix + "classification_heads."):
continue
-head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
-inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+head_name = k[len(prefix + "classification_heads.") :].split(".")[0]  # noqa: E203
+num_classes = state_dict[
+prefix + "classification_heads." + head_name + ".out_proj.weight"
+].size(0)
+inner_dim = state_dict[
+prefix + "classification_heads." + head_name + ".dense.weight"
+].size(0)
-if getattr(self.args, 'load_checkpoint_heads', False):
+if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
-'deleting classification head ({}) from checkpoint '
-'not present in current model: {}'.format(head_name, k)
+"deleting classification head ({}) from checkpoint "
+"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
-num_classes != self.classification_heads[head_name].out_proj.out_features
-or inner_dim != self.classification_heads[head_name].dense.out_features
+num_classes
+!= self.classification_heads[head_name].out_proj.out_features
+or inner_dim
+!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
-'deleting classification head ({}) from checkpoint '
-'with different dimensions than current model: {}'.format(head_name, k)
+"deleting classification head ({}) from checkpoint "
+"with different dimensions than current model: {}".format(
+head_name, k
+)
)
keys_to_delete.append(k)
for k in keys_to_delete:
@@ -338,12 +340,12 @@ class BertModel(BaseFairseqModel):
# Copy any newly-added classification heads into the state dict
# with their current weights.
-if hasattr(self, 'classification_heads'):
+if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
-if prefix + 'classification_heads.' + k not in state_dict:
-logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
-state_dict[prefix + 'classification_heads.' + k] = v
+if prefix + "classification_heads." + k not in state_dict:
+logger.info("Overwriting " + prefix + "classification_heads." + k)
+state_dict[prefix + "classification_heads." + k] = v
def forward(
self,
@@ -354,7 +356,9 @@ class BertModel(BaseFairseqModel):
masked_tokens=None,
**kwargs
):
-encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
+encoder_out = self.encoder(
+src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
+)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)
@@ -455,7 +459,7 @@ def base_unilm_architecture(args):
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
# Model training is not stable without this
-args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)

View File

@@ -9,10 +9,9 @@
import logging
from dataclasses import dataclass, field
from typing import Optional
-import torch
-from fairseq import utils
-from fairseq import distributed_utils
+import torch
+from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
@@ -20,14 +19,13 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
-from fairseq.models.transformer import (
-DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
-)
+from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
-from torchscale.architecture.decoder import Decoder
-from torchscale.architecture.config import DecoderConfig
from omegaconf import II
+from torchscale.architecture.config import DecoderConfig
+from torchscale.architecture.decoder import Decoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@@ -104,49 +102,34 @@ class LanguageConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
-}
+},
)
moe_freq: int = field(
default=0,
-metadata={
-"help": "Frequency at which we insert MoE Transformer layers"
-},
+metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
-default=0,
-metadata={
-"help": "Number of experts in each MoE Layer"
-}
+default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
-metadata={
-"help": "Use FP32 computations in MoE top2 gating function"
-}
+metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
-default='sampling',
-metadata={
-"help": "policy for second expert, options: all/sampling/random"
-}
+default="sampling",
+metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
-"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
-}
+"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
+},
)
moe_expert_ffn_dim: Optional[int] = field(
-default=None,
-metadata={
-"help": "MoE expert FFN dimension"
-}
+default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
-default=False,
-metadata={
-"help": "Use top1 gate instead of top2"
-}
+default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@@ -155,23 +138,29 @@ class LanguageConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
-}
+},
)
moe_normalize_expert_grad: Optional[str] = field(
-default='world_size',
+default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
-}
+},
)
record_a2a_perf_stats: Optional[bool] = field(
-default=False, metadata={"help": "records all to all perf stats during distributed training"}
+default=False,
+metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
-default=False, metadata={
-"help": "By passes all to all during distributed training by returning the input buffer as output"}
+default=False,
+metadata={
+"help": "By passes all to all during distributed training by returning the input buffer as output"
+},
)
moe_batch_prioritized_routing: Optional[bool] = field(
-default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
+default=False,
+metadata={
+"help": "if true orders token by the gate prob before capacity dropping."
+},
)
use_xmoe: Optional[bool] = field(
default=False,
@@ -205,7 +194,6 @@ class LanguageConfig(FairseqDataclass):
@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@@ -245,19 +233,17 @@ class LanguageModel(FairseqLanguageModel):
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
-output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
+output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
-if (
-getattr(args, 'moe_freq', 0) > 0
-and (
-getattr(args, 'fp16', False)
-and not getattr(args, 'memory_efficient_fp16', False)
-and getattr(args, 'ddp_backend', None) != "fully_sharded"
-)
+if getattr(args, "moe_freq", 0) > 0 and (
+getattr(args, "fp16", False)
+and not getattr(args, "memory_efficient_fp16", False)
+and getattr(args, "ddp_backend", None) != "fully_sharded"
):
-assert args.fp16_no_flatten_grads, \
-"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
+assert (
+args.fp16_no_flatten_grads
+), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
@@ -281,7 +267,6 @@ class LanguageModel(FairseqLanguageModel):
class LMDecoder(Decoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

View File

@@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+import logging
from typing import Dict, List, Optional, Tuple
import torch
-from fairseq import utils
+from fairseq import distributed_utils, utils
from fairseq.distributed import utils as fsdp_wrap
-from fairseq import distributed_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
@@ -20,12 +20,13 @@ from fairseq.models import (
)
from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
+from torch import Tensor
+from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.encoder import Encoder
-from torchscale.architecture.config import EncoderConfig, DecoderConfig
from .language_modeling import LMDecoder as MTDecoder
-from torch import Tensor
-import logging
logger = logging.getLogger(__name__)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
@@ -35,7 +36,6 @@ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
@@ -269,7 +269,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
args.decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
-output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
+output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
encoder = cls.build_encoder(
@@ -320,7 +320,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
)
@classmethod
-def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
+def build_decoder(
+cls, args, embed_tokens, embed_positions, output_projection, dictionary
+):
config = DecoderConfig()
config.override(args)
@@ -342,10 +344,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
features_only: bool = False,
**kwargs
):
-encoder_out = self.encoder(
-src_tokens,
-return_all_hiddens=return_all_hiddens
-)
+encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
@@ -365,15 +364,20 @@
class MTEncoder(Encoder, FairseqEncoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
-return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)
+return super().forward(
+src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
+)
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
-new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
-new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)
+new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
+0, new_order
+)
+new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
+0, new_order
+)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:

View File

@@ -3,6 +3,7 @@
import torch
from infinibatch.iterators import CheckpointableIterator
from . import utils
@@ -25,7 +26,6 @@ class BaseBatchGen(CheckpointableIterator):
raise NotImplementedError()
def _move_to_tensor(self, batch):
def to_tensor(x):
return torch.tensor(x)

View File

@@ -1,18 +1,18 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import os
-import numpy as np
-import itertools
import copy
+import itertools
+import os
+import numpy as np
from infinibatch import iterators
from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator
class MLMLoader(BaseBatchGen):
def __init__(
self,
args,
@@ -62,9 +62,7 @@ class MLMLoader(BaseBatchGen):
log_empty_buffer_warning=True and self.shard_id == 0,
)
-prefetch_batches = iterators.MapIterator(
-prefetch_batches, self._move_to_tensor
-)
+prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)
self._iter = prefetch_batches
@@ -73,25 +71,25 @@ class MLMLoader(BaseBatchGen):
weights = []
for data in self.data:
-multilingual_iters.append(
-self._tokenize(data)
-)
-if 'weight' in data:
-weights.append(float(data['weight']))
+multilingual_iters.append(self._tokenize(data))
+if "weight" in data:
+weights.append(float(data["weight"]))
else:
-weights.append(int(data['count']))
+weights.append(int(data["count"]))
if len(multilingual_iters) == 1:
return multilingual_iters[0]
sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
-tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
+tokenized_lines = iterators.MultiplexIterator(
+control_iterator, multilingual_iters
+)
return tokenized_lines
def _tokenize(self, data):
-'''
+"""
data:
{
'source': list[Path],
@@ -100,17 +98,16 @@ class MLMLoader(BaseBatchGen):
'weight': float,
'name': str,
}
-'''
+"""
dataset = list(
zip(
-data['source'],
-itertools.repeat(data['source_lang']),
+data["source"],
+itertools.repeat(data["source_lang"]),
)
)
if self.shuffle:
-chunk_files = \
-iterators.InfinitePermutationSourceIterator(
+chunk_files = iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
@@ -118,15 +115,18 @@ class MLMLoader(BaseBatchGen):
instance_rank=self.shard_id,
)
else:
-chunk_files = \
-iterators.ChunkedSourceIterator(
+chunk_files = iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
-tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
-tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
+tokenized_lines = iterators.SelectManyIterator(
+chunk_files, lambda files: self._read_from_files(*files)
+)
+tokenized_lines = iterators.SamplingRandomMapIterator(
+tokenized_lines, self._prepare, self.seed
+)
return tokenized_lines
@@ -134,13 +134,20 @@ class MLMLoader(BaseBatchGen):
if self.max_sentences is not None:
if self.batch_read_ahead > 0:
-lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
+lines = iterators.BlockwiseShuffleIterator(
+lines, self.batch_read_ahead, self.seed
+)
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
else:
def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths)
-batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
+batch_size = (
+batch_size
+// self.required_batch_size_multiple
+* self.required_batch_size_multiple
+)
return max(1, batch_size)
batches = iterators.BucketedReadaheadBatchIterator(
@@ -160,38 +167,56 @@ class MLMLoader(BaseBatchGen):
s2s_source_max_length = max([len(x[2]) for x in batch])
s2s_target_max_length = max([len(x[3]) for x in batch])
-mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
-fill_value=self.dictionary.pad())
-mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
-fill_value=self.dictionary.pad())
-s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
-fill_value=self.dictionary.pad())
-s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
-fill_value=self.dictionary.pad())
-s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
-fill_value=self.dictionary.pad())
+mlm_source_ids = np.full(
+shape=(batch_size, mlm_source_max_length),
+dtype=np.int32,
+fill_value=self.dictionary.pad(),
+)
+mlm_target_ids = np.full(
+shape=(batch_size, mlm_target_max_length),
+dtype=np.int32,
+fill_value=self.dictionary.pad(),
+)
+s2s_source_ids = np.full(
+shape=(batch_size, s2s_source_max_length),
+dtype=np.int32,
+fill_value=self.dictionary.pad(),
+)
+s2s_target_ids = np.full(
+shape=(batch_size, s2s_target_max_length - 1),
+dtype=np.int32,
+fill_value=self.dictionary.pad(),
+)
+s2s_prev_input_ids = np.full(
+shape=(batch_size, s2s_target_max_length - 1),
+dtype=np.int32,
+fill_value=self.dictionary.pad(),
+)
-for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
-mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
-mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
-s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
-s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
-s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
+for i, (
+mlm_input_ids,
+mlm_label_ids,
+s2s_input_ids,
+s2s_label_ids,
+) in enumerate(batch):
+mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
+mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
+s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
+s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
+s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]
ret_batch = {
-'net_input': {
-'src_tokens': mlm_source_ids.astype(np.int64),
+"net_input": {
+"src_tokens": mlm_source_ids.astype(np.int64),
},
-'target': mlm_target_ids.astype(np.int64),
-'nsentences': batch_size,
-'ntokens': sum([len(x[0]) for x in batch]),
+"target": mlm_target_ids.astype(np.int64),
+"nsentences": batch_size,
+"ntokens": sum([len(x[0]) for x in batch]),
}
return ret_batch
-padded_batches = iterators.MapIterator(
-batches, collate
-)
+padded_batches = iterators.MapIterator(batches, collate)
return padded_batches
@@ -221,7 +246,6 @@ class MLMLoader(BaseBatchGen):
return nonmasked_tokens, masked_tokens
def _span_corruption(self, _random, doc):
def mask_tokens(i):
return f"<mask_{i}>"
@@ -237,7 +261,9 @@ class MLMLoader(BaseBatchGen):
else:
possible_split_positions = list(range(1, noise_tokens_num))
_random.shuffle(possible_split_positions)
-noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
+noise_split_positions = sorted(
+possible_split_positions[: noise_spans_num - 1]
+)
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
possible_insert_positions = list(range(nonnoise_tokens_num))
@@ -248,7 +274,7 @@ class MLMLoader(BaseBatchGen):
last_end = 0
for i in range(noise_spans_num):
start_pos = noise_insert_positions[i] + noise_split_positions[i]
-end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
+end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
mask_id = self.dictionary.indices[mask_tokens(i)]
if getattr(self.args, "remove_target_sentinel", False):
@@ -273,23 +299,25 @@ class MLMLoader(BaseBatchGen):
file_path = os.path.join(self.data_dir, source_file)
if not os.path.exists(file_path):
-print('| file {} not exists'.format(file_path), flush=True)
+print("| file {} not exists".format(file_path), flush=True)
return iter([])  # skip bad file
-with open(file_path, 'r', encoding='utf8') as f:
-lines = f.read().strip().split('\n')
+with open(file_path, "r", encoding="utf8") as f:
+lines = f.read().strip().split("\n")
doc = [self.dictionary.bos()]
for line in lines:
if line == "":
-if self.sample_break_mode == 'complete_doc':
+if self.sample_break_mode == "complete_doc":
# data.append(doc)
yield doc
doc = [self.dictionary.bos()]
continue
tokenized_line = self.tokenizer.EncodeAsPieces(line)
-tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
+tokenized_id = [
+self.dictionary.index(token) for token in tokenized_line
+] + [self.dictionary.eos_index]
if len(tokenized_id) > self.tokens_per_sample:
continue

View File

@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import numpy as np
+import collections
from random import Random
from typing import Dict, Iterable, Optional
-import collections
+import numpy as np
from infinibatch import iterators
@@ -17,7 +18,9 @@ def apply_to_sample(f, sample):
return f(x)
elif isinstance(x, collections.OrderedDict):
# OrderedDict has attributes that needs to be preserved
-od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
+od = collections.OrderedDict(
+(key, _apply(value)) for key, value in x.items()
+)
od.__dict__ = x.__dict__
return od
elif isinstance(x, dict):
@@ -40,14 +43,15 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
self.setstate(None)
def getstate(self) -> Dict:
-return {'num_items_yielded': self._num_items_yielded}
+return {"num_items_yielded": self._num_items_yielded}
def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
-self._num_items_yielded = iterators._advance_iterator(
-self._iterator,
-checkpoint['num_items_yielded']
-) if checkpoint is not None else 0
+self._num_items_yielded = (
+iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
+if checkpoint is not None
+else 0
+)
def __next__(self):
item = next(self._iterator)
@@ -73,7 +77,9 @@ class WeightIterator(object):
def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
-self._random = None  # this will trigger the lazy initialization in self.__next__
+self._random = (
+None  # this will trigger the lazy initialization in self.__next__
+)
def __next__(self):
if self._random is None:

View File

@@ -1,23 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
+import json
+import logging
+import os
+from argparse import Namespace
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
-import logging
-import os
-from argparse import Namespace
-import json
-from omegaconf import MISSING, II
+import sentencepiece as spm
from fairseq import utils
from fairseq.data import Dictionary
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
+from omegaconf import II, MISSING
from .data.mlm_loader import MLMLoader
-from fairseq.dataclass import FairseqDataclass, ChoiceEnum
-import sentencepiece as spm
logger = logging.getLogger(__name__)
@@ -109,21 +111,16 @@ class PretrainingConfig(FairseqDataclass):
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
spm_model: str = field(
default="",
-metadata={
-"help": "sentencepice model to tokenize the data"
-},
+metadata={"help": "sentencepice model to tokenize the data"},
)
dict_file: str = field(
default="",
-metadata={
-"help": ""
-},
+metadata={"help": ""},
)
@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):
def __init__(self, cfg, dictionary, tokenizer):
super().__init__(cfg)
self.cfg = cfg
@@ -156,9 +153,9 @@ class PLMTask(FairseqTask):
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = {
-'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
-'data_dir': self.cfg.data,
-'shuffle': True if split == 'train' else False,
+"data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
+"data_dir": self.cfg.data,
+"shuffle": True if split == "train" else False,
}
self.datasets[split] = Namespace(**self.datasets[split])

View File

@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks
from fairseq_cli.train import cli_main
if __name__ == "__main__":

View File

@@ -1,17 +1,21 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
-import torch
-import warnings
-from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
-import torch.distributed as dist
import math
+import warnings
+import torch
+import torch.distributed as dist
+from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
@torch.no_grad()
-def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
+def clip_grad_norm_(
+params, max_norm, moe_expert_count, aggregate_norm_fn=None
+) -> torch.Tensor:
def grad_exists(p):
return p is not None and getattr(p, "grad", None) is not None
if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
@@ -59,7 +63,9 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
for split_grads in [expert_grads, sharded_grads]:
if len(split_grads) == 0:
continue
-split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
+split_norm = torch.norm(
+torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
+)
if dist.is_initialized():
split_norm.pow_(2)
dist.all_reduce(split_norm)

View File

@@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]
from io import open
from setuptools import find_packages, setup
setup(
@@ -10,19 +11,15 @@ setup(
author="TorchScale Team",
author_email="Shuming.Ma@microsoft.com",
description="Transformers at any scale",
-long_description=open("README.md", "r", encoding='utf-8').read(),
+long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="Transformers at any scale",
license="MIT",
url="https://github.com/msranlp/torchscale",
-packages=find_packages(exclude=["*.tests", "*.tests.*",
-"tests.*", "tests"]),
-install_requires=['apex',
-'torch>=1.8',
-'fairscale==0.4.0',
-'timm==0.4.12'],
-python_requires='>=3.8.0',
+packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
+install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
+python_requires=">=3.8.0",
classifiers=[
-'Programming Language :: Python :: 3',
+"Programming Language :: Python :: 3",
],
)

View File

@@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
+import torch
from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
-import torch
testcases = [
{},
@@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"checkpoint_activations": True},
-{"fsdp": True}
+{"fsdp": True},
]

View File

@@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
+import torch
from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
-import torch
testcases = [
{},
@@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_encoder_input_output_embed": True},
{"checkpoint_activations": True},
-{"fsdp": True}
+{"fsdp": True},
]

View File

@@ -2,10 +2,11 @@
# Licensed under The MIT License [see LICENSE for details]
import pytest
+import torch
from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
-from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
-import torch
+from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
testcases = [
{},
@@ -16,13 +17,18 @@ testcases = [
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
-{"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
+{
+"deepnorm": True,
+"subln": False,
+"encoder_normalize_before": False,
+"decoder_normalize_before": False,
+},
{"bert_init": True},
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"share_all_embeddings": True},
{"checkpoint_activations": True},
-{"fsdp": True}
+{"fsdp": True},
]
@@ -33,8 +39,12 @@ def test_decoder(args):
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
-encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
-decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
+encoder_embed_positions=PositionalEmbedding(
+config.max_source_positions, config.encoder_embed_dim
+),
+decoder_embed_positions=PositionalEmbedding(
+config.max_target_positions, config.decoder_embed_dim
+),
)
src_tokens = torch.ones(2, 20).long()

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
class EncoderConfig(object):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
@@ -19,9 +20,13 @@ class EncoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
+self.moe_eval_capacity_token_fraction = kwargs.pop(
+"moe_eval_capacity_token_fraction", 0.25
+)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
+self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
+"moe_normalize_gate_prob_before_dropping", False
+)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -29,7 +34,9 @@ class EncoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
-self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
+self.share_encoder_input_output_embed = kwargs.pop(
+"share_encoder_input_output_embed", False
+)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@@ -78,9 +85,13 @@ class DecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
+self.moe_eval_capacity_token_fraction = kwargs.pop(
+"moe_eval_capacity_token_fraction", 0.25
+)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
+self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
+"moe_normalize_gate_prob_before_dropping", False
+)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -88,7 +99,9 @@ class DecoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
-self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
+self.share_decoder_input_output_embed = kwargs.pop(
+"share_decoder_input_output_embed", False
+)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@@ -138,9 +151,13 @@ class EncoderDecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
-self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
+self.moe_eval_capacity_token_fraction = kwargs.pop(
+"moe_eval_capacity_token_fraction", 0.25
+)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
-self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
+self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
+"moe_normalize_gate_prob_before_dropping", False
+)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -149,7 +166,9 @@ class EncoderDecoderConfig(object):
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
-self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
+self.share_decoder_input_output_embed = kwargs.pop(
+"share_decoder_input_output_embed", False
+)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)

View File

@@ -2,22 +2,23 @@
# Licensed under The MIT License [see LICENSE for details]
import math
-import numpy as np
import torch
import torch.nn as nn
+import numpy as np
-from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
+from fairscale.nn import checkpoint_wrapper, wrap
+from torchscale.architecture.utils import init_bert_params
+from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
-from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
-from torchscale.component.xmoe.moe_layer import MOELayer
-from torchscale.component.droppath import DropPath
-from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias
+from torchscale.component.xmoe.moe_layer import MOELayer
+from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
class DecoderLayer(nn.Module):
def __init__(
self,
args,
@@ -31,7 +32,9 @@ class DecoderLayer(nn.Module):
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0:
-drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
+drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
+depth
+]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
@@ -206,7 +209,6 @@ class DecoderLayer(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
args,
@@ -228,7 +230,11 @@ class Decoder(nn.Module):
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions
-if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
+if (
+output_projection is None
+and not args.no_output_layer
+and args.vocab_size > 0
+):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
@@ -286,7 +292,12 @@ class Decoder(nn.Module):
else:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
-if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
+if (
+"fc1" in name
+or "fc2" in name
+or "out_proj" in name
+or "v_proj" in name
+):
p.data.div_(init_scale)
if args.subln:
@@ -295,9 +306,14 @@ class Decoder(nn.Module):
else:
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
for name, p in self.named_parameters():
-if 'encoder_attn' in name:
+if "encoder_attn" in name:
continue
-if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
+if (
+"fc1" in name
+or "fc2" in name
+or "out_proj" in name
+or "v_proj" in name
+):
p.data.mul_(init_scale)
def build_output_projection(
@ -316,16 +332,12 @@ class Decoder(nn.Module):
args.decoder_embed_dim, args.vocab_size, bias=False args.decoder_embed_dim, args.vocab_size, bias=False
) )
torch.nn.init.normal_( torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5 output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
) )
return output_projection return output_projection
def build_decoder_layer( def build_decoder_layer(
self, self, args, depth, is_moe_layer=False, is_encoder_decoder=False
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
): ):
layer = DecoderLayer( layer = DecoderLayer(
args, args,
@ -347,7 +359,9 @@ class Decoder(nn.Module):
): ):
positions = None positions = None
if self.embed_positions is not None: if self.embed_positions is not None:
positions = self.embed_positions(tokens, incremental_state=incremental_state) positions = self.embed_positions(
tokens, incremental_state=incremental_state
)
if incremental_state is not None: if incremental_state is not None:
tokens = tokens[:, -1:] tokens = tokens[:, -1:]
@ -381,7 +395,9 @@ class Decoder(nn.Module):
**kwargs **kwargs
): ):
# embed tokens and positions # embed tokens and positions
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state) x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
x = x.transpose(0, 1) x = x.transpose(0, 1)
# relative position # relative position
@ -389,9 +405,7 @@ class Decoder(nn.Module):
slen = prev_output_tokens.size(1) slen = prev_output_tokens.size(1)
if self.self_attn_relative_position is not None: if self.self_attn_relative_position is not None:
self_attn_rel_pos_bias = self.self_attn_relative_position( self_attn_rel_pos_bias = self.self_attn_relative_position(
batch_size=x.size(1), batch_size=x.size(1), qlen=slen, klen=slen
qlen=slen,
klen=slen
) )
if incremental_state is not None: if incremental_state is not None:
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :] self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
@ -416,7 +430,11 @@ class Decoder(nn.Module):
for idx, layer in enumerate(self.layers): for idx, layer in enumerate(self.layers):
if incremental_state is None: if incremental_state is None:
self_attn_mask = torch.triu( self_attn_mask = torch.triu(
torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1 torch.zeros([x.size(0), x.size(0)])
.float()
.fill_(float("-inf"))
.type_as(x),
1,
) )
else: else:
self_attn_mask = None self_attn_mask = None
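A minimal sketch, assuming a toy sequence length of 4, of what the triangular mask built above evaluates to:

import torch

x = torch.randn(4, 2, 8)  # (seq_len, batch, dim), toy values
self_attn_mask = torch.triu(
    torch.zeros([x.size(0), x.size(0)])
    .float()
    .fill_(float("-inf"))
    .type_as(x),
    1,
)
# The strict upper triangle is -inf and everything else is 0, so
# position i can only attend to positions <= i:
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
print(self_attn_mask)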
@ -426,7 +444,9 @@ class Decoder(nn.Module):
x, layer_attn, _, l_aux_i = layer( x, layer_attn, _, l_aux_i = layer(
x, x,
encoder_out["encoder_out"] if encoder_out is not None else None, encoder_out["encoder_out"] if encoder_out is not None else None,
encoder_out["encoder_padding_mask"] if encoder_out is not None else None, encoder_out["encoder_padding_mask"]
if encoder_out is not None
else None,
incremental_state[idx] if incremental_state is not None else None, incremental_state[idx] if incremental_state is not None else None,
self_attn_mask=self_attn_mask, self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask, self_attn_padding_mask=self_attn_padding_mask,
@ -444,7 +464,11 @@ class Decoder(nn.Module):
if not features_only: if not features_only:
x = self.output_layer(x) x = self.output_layer(x)
return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]} return x, {
"inner_states": inner_states,
"l_aux": l_aux,
"attn": [layer_attn.mean(dim=0)],
}
def output_layer(self, features): def output_layer(self, features):
return self.output_projection(features) return self.output_projection(features)

View File

@ -2,30 +2,25 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import math import math
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
def __init__(
self,
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
):
super().__init__() super().__init__()
self.args = args self.args = args
self.embed_dim = args.encoder_embed_dim self.embed_dim = args.encoder_embed_dim
@ -34,7 +29,9 @@ class EncoderLayer(nn.Module):
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True) self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
if args.drop_path_rate > 0: if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth] drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob) self.drop_path = DropPath(drop_path_prob)
else: else:
self.drop_path = None self.drop_path = None
@ -49,7 +46,7 @@ class EncoderLayer(nn.Module):
self.build_ffn( self.build_ffn(
self.embed_dim, self.embed_dim,
self.args, self.args,
) ),
) )
else: else:
assert not self.args.multiway assert not self.args.multiway
@ -77,7 +74,12 @@ class EncoderLayer(nn.Module):
if args.deepnorm: if args.deepnorm:
if is_encoder_decoder: if is_encoder_decoder:
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81 self.alpha = (
math.pow(
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
)
* 0.81
)
else: else:
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25) self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
else: else:
@ -107,13 +109,7 @@ class EncoderLayer(nn.Module):
def residual_connection(self, x, residual): def residual_connection(self, x, residual):
return residual * self.alpha + x return residual * self.alpha + x
def forward( def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
self,
x,
encoder_padding_mask,
attn_mask=None,
rel_pos=None
):
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
@ -158,7 +154,6 @@ class EncoderLayer(nn.Module):
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__( def __init__(
self, self,
args, args,
@ -179,13 +174,20 @@ class Encoder(nn.Module):
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
self.embed_positions = embed_positions self.embed_positions = embed_positions
if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0: if (
output_projection is None
and not is_encoder_decoder
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args) self.output_projection = self.build_output_projection(args)
else: else:
self.output_projection = output_projection self.output_projection = output_projection
if args.layernorm_embedding: if args.layernorm_embedding:
self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1) self.layernorm_embedding = MultiwayWrapper(
args, LayerNorm(embed_dim), dim=1
)
else: else:
self.layernorm_embedding = None self.layernorm_embedding = None
@ -199,7 +201,7 @@ class Encoder(nn.Module):
args, args,
depth=i, depth=i,
is_moe_layer=is_moe_layer, is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder is_encoder_decoder=is_encoder_decoder,
) )
) )
self.num_layers = len(self.layers) self.num_layers = len(self.layers)
@ -223,20 +225,39 @@ class Encoder(nn.Module):
if args.deepnorm: if args.deepnorm:
if is_encoder_decoder: if is_encoder_decoder:
init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15 init_scale = (
math.pow(
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
)
/ 1.15
)
else: else:
init_scale = math.pow(8.0 * args.encoder_layers, 0.25) init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
for name, p in self.named_parameters(): for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name: if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale) p.data.div_(init_scale)
if args.subln: if args.subln:
if is_encoder_decoder: if is_encoder_decoder:
init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3) init_scale = math.sqrt(
math.log(3 * args.decoder_layers)
* math.log(2 * args.encoder_layers)
/ 3
)
else: else:
init_scale = math.sqrt(math.log(args.encoder_layers * 2)) init_scale = math.sqrt(math.log(args.encoder_layers * 2))
for name, p in self.named_parameters(): for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name: if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale) p.data.mul_(init_scale)
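A minimal sketch, assuming illustrative 12-layer encoder and decoder stacks, of the initialization scales computed by the deepnorm and subln branches above:

import math

encoder_layers, decoder_layers = 12, 12  # assumed toy depths

# deepnorm branch: selected weights are divided by init_scale
enc_dec_scale = math.pow(math.pow(encoder_layers, 4) * decoder_layers, 0.0625) / 1.15
enc_only_scale = math.pow(8.0 * encoder_layers, 0.25)
print(enc_dec_scale, enc_only_scale)  # ~1.89, ~3.13

# subln branch: selected weights are multiplied by init_scale
subln_enc_dec = math.sqrt(math.log(3 * decoder_layers) * math.log(2 * encoder_layers) / 3)
subln_enc_only = math.sqrt(math.log(encoder_layers * 2))
print(subln_enc_dec, subln_enc_only)  # ~1.95, ~1.78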
def build_output_projection( def build_output_projection(
@ -244,7 +265,7 @@ class Encoder(nn.Module):
args, args,
): ):
if args.share_encoder_input_output_embed: if args.share_encoder_input_output_embed:
assert args.encoder_embedding_type == 'language' assert args.encoder_embedding_type == "language"
output_projection = torch.nn.Linear( output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1], self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0], self.embed_tokens.weight.shape[0],
@ -256,22 +277,18 @@ class Encoder(nn.Module):
args.encoder_embed_dim, args.vocab_size, bias=False args.encoder_embed_dim, args.vocab_size, bias=False
) )
torch.nn.init.normal_( torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5 output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
) )
return output_projection return output_projection
def build_encoder_layer( def build_encoder_layer(
self, self, args, depth, is_moe_layer=False, is_encoder_decoder=False
args,
depth,
is_moe_layer=False,
is_encoder_decoder=False
): ):
layer = EncoderLayer( layer = EncoderLayer(
args, args,
depth, depth,
is_moe_layer=is_moe_layer, is_moe_layer=is_moe_layer,
is_encoder_decoder=is_encoder_decoder is_encoder_decoder=is_encoder_decoder,
) )
if args.checkpoint_activations: if args.checkpoint_activations:
layer = checkpoint_wrapper(layer) layer = checkpoint_wrapper(layer)
@ -312,13 +329,12 @@ class Encoder(nn.Module):
if encoder_padding_mask is None: if encoder_padding_mask is None:
if src_tokens is not None: if src_tokens is not None:
encoder_padding_mask = torch.zeros_like( encoder_padding_mask = torch.zeros_like(
src_tokens, src_tokens, device=src_tokens.device
device=src_tokens.device
).bool() ).bool()
else: else:
encoder_padding_mask = torch.zeros( encoder_padding_mask = torch.zeros(
[token_embeddings.size(0), token_embeddings.size(1)], [token_embeddings.size(0), token_embeddings.size(1)],
device=token_embeddings.device device=token_embeddings.device,
).bool() ).bool()
if multiway_split_position is not None: if multiway_split_position is not None:
@ -338,16 +354,13 @@ class Encoder(nn.Module):
rel_pos_bias = None rel_pos_bias = None
if self.relative_position is not None: if self.relative_position is not None:
rel_pos_bias = self.relative_position( rel_pos_bias = self.relative_position(
batch_size=x.size(1), batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
qlen=x.size(0),
klen=x.size(0)
) )
l_aux = [] l_aux = []
for layer in self.layers: for layer in self.layers:
x, l_aux_i = layer( x, l_aux_i = layer(
x, encoder_padding_mask=encoder_padding_mask, x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
rel_pos=rel_pos_bias
) )
if return_all_hiddens: if return_all_hiddens:
assert encoder_states is not None assert encoder_states is not None

View File

@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn import torch.nn as nn
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.decoder import Decoder from torchscale.architecture.decoder import Decoder
from torchscale.architecture.encoder import Encoder
class EncoderDecoder(nn.Module): class EncoderDecoder(nn.Module):
def __init__( def __init__(
self, self,
args, args,
@ -51,10 +51,7 @@ class EncoderDecoder(nn.Module):
features_only=False, features_only=False,
**kwargs **kwargs
): ):
encoder_out = self.encoder( encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
src_tokens,
return_all_hiddens=return_all_hiddens
)
decoder_out = self.decoder( decoder_out = self.decoder(
prev_output_tokens, prev_output_tokens,
encoder_out=encoder_out, encoder_out=encoder_out,

View File

@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn import torch.nn as nn
from torchscale.component.multihead_attention import MultiheadAttention from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.multiway_network import MultiwayNetwork from torchscale.component.multiway_network import MultiwayNetwork
def init_bert_params(module): def init_bert_params(module):
def normal_(data): def normal_(data):
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

View File

@ -1,13 +1,13 @@
# Copyright (c) 2022 Microsoft # Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
from timm.models.layers import drop_path
import torch.nn as nn import torch.nn as nn
from timm.models.layers import drop_path
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
"""
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
@ -16,4 +16,4 @@ class DropPath(nn.Module):
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
def extra_repr(self): def extra_repr(self):
return 'p={}'.format(self.drop_prob) return "p={}".format(self.drop_prob)
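A minimal sketch, assuming a 12-layer stack and a drop_path_rate of 0.1, of the linearly increasing per-layer probabilities that the encoder and decoder layers pass into DropPath:

import numpy as np

drop_path_rate, num_layers = 0.1, 12  # assumed toy settings
per_layer = np.linspace(0, drop_path_rate, num_layers)
# The first layer never drops its residual branch; the last one drops
# it with probability drop_path_rate during training.
print([round(float(p), 3) for p in per_layer])
# [0.0, 0.009, 0.018, ..., 0.091, 0.1]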

View File

@ -7,22 +7,12 @@ import torch.nn.functional as F
class VisionLanguageEmbedding(nn.Module): class VisionLanguageEmbedding(nn.Module):
def __init__(self, text_embed, vision_embed):
def __init__(
self,
text_embed,
vision_embed
):
super().__init__() super().__init__()
self.text_embed = text_embed self.text_embed = text_embed
self.vision_embed = vision_embed self.vision_embed = vision_embed
def forward( def forward(self, textual_tokens, visual_tokens, **kwargs):
self,
textual_tokens,
visual_tokens,
**kwargs
):
if textual_tokens is None: if textual_tokens is None:
return self.vision_embed(visual_tokens) return self.vision_embed(visual_tokens)
@ -36,8 +26,8 @@ class VisionLanguageEmbedding(nn.Module):
class VisionEmbedding(nn.Module): class VisionEmbedding(nn.Module):
""" Image to Patch Embedding """Image to Patch Embedding"""
"""
def __init__( def __init__(
self, self,
img_size=224, img_size=224,
@ -45,7 +35,7 @@ class VisionEmbedding(nn.Module):
in_chans=3, in_chans=3,
embed_dim=768, embed_dim=768,
contain_mask_token=False, contain_mask_token=False,
prepend_cls_token=False prepend_cls_token=False,
): ):
super().__init__() super().__init__()
img_size = (img_size, img_size) img_size = (img_size, img_size)
@ -56,7 +46,9 @@ class VisionEmbedding(nn.Module):
self.patch_size = patch_size self.patch_size = patch_size
self.num_patches = num_patches self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
if contain_mask_token: if contain_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@ -68,15 +60,11 @@ class VisionEmbedding(nn.Module):
else: else:
self.cls_token = None self.cls_token = None
def forward( def forward(self, x, masked_position=None, **kwargs):
self,
x,
masked_position=None,
**kwargs
):
B, C, H, W = x.shape B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert (
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) x = self.proj(x).flatten(2).transpose(1, 2)
batch_size, seq_len, _ = x.size() batch_size, seq_len, _ = x.size()
@ -88,21 +76,21 @@ class VisionEmbedding(nn.Module):
x = x * (1 - w) + mask_token * w x = x * (1 - w) + mask_token * w
if self.cls_token is not None: if self.cls_token is not None:
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1) x = torch.cat((cls_tokens, x), dim=1)
return x return x
class TextEmbedding(nn.Embedding): class TextEmbedding(nn.Embedding):
def reset_parameters(self): def reset_parameters(self):
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5) nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
self._fill_padding_idx_with_zero() self._fill_padding_idx_with_zero()
class PositionalEmbedding(nn.Embedding): class PositionalEmbedding(nn.Embedding):
def forward( def forward(
self, self,
x, x,
@ -111,7 +99,9 @@ class PositionalEmbedding(nn.Embedding):
): ):
if positions is None: if positions is None:
# being consistent with Fairseq, which starts from 2. # being consistent with Fairseq, which starts from 2.
positions = torch.arange(2, x.size(1)+2, device=x.device).long().unsqueeze(0) positions = (
torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
)
return F.embedding( return F.embedding(
positions, positions,
self.weight, self.weight,
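A minimal sketch, assuming the usual (img_size // patch_size) ** 2 patch grid and a hypothetical 224x224 image with 16x16 patches, of the patch count and of the position ids PositionalEmbedding generates when none are passed:

import torch

img_size, patch_size = 224, 16  # assumed values for illustration
num_patches = (img_size // patch_size) * (img_size // patch_size)
print(num_patches)  # 196 patch tokens per image

# Positions start at 2 to stay consistent with Fairseq, which
# conventionally reserves index 1 for padding.
x = torch.zeros(1, 5, 8)  # (batch, seq_len, dim), toy values
positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
print(positions)  # tensor([[2, 3, 4, 5, 6]])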

View File

@ -35,13 +35,19 @@ class set_torch_seed(object):
def make_experts(args, embed_dim, expert_ffn_dim): def make_experts(args, embed_dim, expert_ffn_dim):
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size() world_size = (
1
if not torch.distributed.is_initialized()
else torch.distributed.get_world_size()
)
expert_list = [] expert_list = []
ddp_rank = args.ddp_rank ddp_rank = args.ddp_rank
start_seed = torch.randint(1000000, (1,)).item() start_seed = torch.randint(1000000, (1,)).item()
# at least as many experts as gpus # at least as many experts as gpus
if args.moe_expert_count >= world_size: if args.moe_expert_count >= world_size:
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}' assert (
args.moe_expert_count % world_size == 0
), f"{args.moe_expert_count}, {world_size}"
local_moe_expert_count = args.moe_expert_count // world_size local_moe_expert_count = args.moe_expert_count // world_size
for i in range(local_moe_expert_count): for i in range(local_moe_expert_count):
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i): with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
@ -52,11 +58,13 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn, args.activation_fn,
args.dropout, args.dropout,
args.activation_dropout, args.activation_dropout,
args.subln args.subln,
) )
) )
else: else:
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}' assert (
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count): with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
expert_list.append( expert_list.append(
@ -66,7 +74,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn, args.activation_fn,
args.dropout, args.dropout,
args.activation_dropout, args.activation_dropout,
args.subln args.subln,
) )
) )
experts = nn.ModuleList(expert_list) experts = nn.ModuleList(expert_list)
@ -83,7 +91,6 @@ def get_activation_fn(activation):
class FeedForwardNetwork(nn.Module): class FeedForwardNetwork(nn.Module):
def __init__( def __init__(
self, self,
embed_dim, embed_dim,
@ -91,12 +98,14 @@ class FeedForwardNetwork(nn.Module):
activation_fn, activation_fn,
dropout, dropout,
activation_dropout, activation_dropout,
subln=False subln=False,
): ):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn)) self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True) self.activation_dropout_module = torch.nn.Dropout(
activation_dropout, inplace=True
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True) self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim) self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim) self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
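A minimal sketch, in plain Python with assumed toy sizes, of how make_experts splits the expert count across data-parallel ranks in the two branches above:

moe_expert_count, world_size = 8, 4  # assumed: 8 experts spread over 4 GPUs

if moe_expert_count >= world_size:
    # Each rank builds moe_expert_count // world_size local experts,
    # seeded per expert so every rank gets distinct weights.
    local_moe_expert_count = moe_expert_count // world_size
    print(local_moe_expert_count)  # 2 experts per rank
else:
    # Fewer experts than ranks: several ranks share one expert,
    # grouped by ddp_rank % moe_expert_count.
    print(world_size // moe_expert_count)  # ranks per shared expert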

View File

@ -2,15 +2,16 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import math import math
import torch import torch
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm from apex.normalization import FusedLayerNorm as LayerNorm
from torch import nn
from .multiway_network import MultiwayWrapper from .multiway_network import MultiwayWrapper
class MultiheadAttention(nn.Module): class MultiheadAttention(nn.Module):
def __init__( def __init__(
self, self,
args, args,
@ -25,7 +26,7 @@ class MultiheadAttention(nn.Module):
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
self.head_dim = embed_dim // num_heads self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5 self.scaling = self.head_dim**-0.5
self.self_attention = self_attention self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention self.encoder_decoder_attention = encoder_decoder_attention
@ -34,8 +35,14 @@ class MultiheadAttention(nn.Module):
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True)) self.out_proj = MultiwayWrapper(
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None args, nn.Linear(embed_dim, embed_dim, bias=True)
)
self.inner_attn_ln = (
MultiwayWrapper(args, LayerNorm(self.embed_dim))
if subln and self.self_attention
else None
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True) self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
def reset_parameters(self): def reset_parameters(self):
@ -76,12 +83,20 @@ class MultiheadAttention(nn.Module):
if incremental_state is not None: if incremental_state is not None:
if "prev_key" in incremental_state: if "prev_key" in incremental_state:
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim) prev_key = incremental_state["prev_key"].view(
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim) bsz * self.num_heads, -1, self.head_dim
)
prev_value = incremental_state["prev_value"].view(
bsz * self.num_heads, -1, self.head_dim
)
k = torch.cat([prev_key, k], dim=1) k = torch.cat([prev_key, k], dim=1)
v = torch.cat([prev_value, v], dim=1) v = torch.cat([prev_value, v], dim=1)
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) incremental_state["prev_key"] = k.view(
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) bsz, self.num_heads, -1, self.head_dim
)
incremental_state["prev_value"] = v.view(
bsz, self.num_heads, -1, self.head_dim
)
src_len = k.size(1) src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2)) attn_weights = torch.bmm(q, k.transpose(1, 2))
@ -103,7 +118,9 @@ class MultiheadAttention(nn.Module):
rel_pos = rel_pos.view(attn_weights.size()) rel_pos = rel_pos.view(attn_weights.size())
attn_weights = attn_weights + rel_pos attn_weights = attn_weights + rel_pos
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights) attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
)
attn_probs = self.dropout_module(attn_weights) attn_probs = self.dropout_module(attn_weights)
attn = torch.bmm(attn_probs, v) attn = torch.bmm(attn_probs, v)
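A minimal sketch, using plain tensors and assumed toy shapes, of the shape bookkeeping in this attention: heads are folded into the batch dimension, scores are scaled by head_dim ** -0.5 (applied to q here for illustration), and the softmax runs in float32:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 16  # assumed toy sizes

q = torch.randn(bsz * num_heads, tgt_len, head_dim) * head_dim**-0.5
k = torch.randn(bsz * num_heads, src_len, head_dim)
v = torch.randn(bsz * num_heads, src_len, head_dim)

attn_weights = torch.bmm(q, k.transpose(1, 2))  # (bsz * heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn = torch.bmm(attn_weights, v)  # (bsz * heads, tgt_len, head_dim)
print(attn.shape)  # torch.Size([8, 5, 16])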

View File

@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import copy import copy
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -13,16 +14,14 @@ def MultiwayWrapper(args, module, dim=0):
def set_split_position(position): def set_split_position(position):
def apply_fn(module): def apply_fn(module):
if hasattr(module, 'split_position'): if hasattr(module, "split_position"):
module.split_position = position module.split_position = position
return apply_fn return apply_fn
class MultiwayNetwork(nn.Module): class MultiwayNetwork(nn.Module):
def __init__(self, module, dim=0): def __init__(self, module, dim=0):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -36,7 +35,11 @@ class MultiwayNetwork(nn.Module):
return self.A(x, **kwargs) return self.A(x, **kwargs)
if self.split_position == 0: if self.split_position == 0:
return self.B(x, **kwargs) return self.B(x, **kwargs)
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim) x1, x2 = torch.split(
x,
[self.split_position, x.size(self.dim) - self.split_position],
dim=self.dim,
)
# x1, x2 = x[:self.split_position], x[self.split_position:] # x1, x2 = x[:self.split_position], x[self.split_position:]
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs) y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
return torch.cat([y1, y2], dim=self.dim) return torch.cat([y1, y2], dim=self.dim)
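A minimal sketch, with plain linear layers standing in for the A/B branches, of the split-and-route step above, assuming the first three positions of a toy sequence go to branch A and the rest to branch B:

import copy
import torch
import torch.nn as nn

module_a = nn.Linear(8, 8)
module_b = copy.deepcopy(module_a)  # stand-in for the second branch copy

x = torch.randn(10, 2, 8)   # (seq_len, batch, dim), toy values
split_position, dim = 3, 0  # e.g. text tokens first, vision tokens after

x1, x2 = torch.split(x, [split_position, x.size(dim) - split_position], dim=dim)
y = torch.cat([module_a(x1), module_b(x2)], dim=dim)
print(y.shape)  # torch.Size([10, 2, 8])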

View File

@ -2,17 +2,14 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
class RelativePositionBias(nn.Module): class RelativePositionBias(nn.Module):
def __init__( def __init__(
self, self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
bidirectional=True,
num_buckets=32,
max_distance=128,
n_heads=12
): ):
super().__init__() super().__init__()
self.bidirectional = bidirectional self.bidirectional = bidirectional
@ -23,10 +20,7 @@ class RelativePositionBias(nn.Module):
@staticmethod @staticmethod
def _relative_position_bucket( def _relative_position_bucket(
relative_position, relative_position, bidirectional=True, num_buckets=32, max_distance=128
bidirectional=True,
num_buckets=32,
max_distance=128
): ):
ret = 0 ret = 0
n = -relative_position n = -relative_position
@ -41,24 +35,28 @@ class RelativePositionBias(nn.Module):
is_small = n < max_exact is_small = n < max_exact
val_if_large = max_exact + ( val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long) ).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large) ret += torch.where(is_small, n, val_if_large)
return ret return ret
def compute_bias( def compute_bias(self, qlen, klen, step=None):
self,
qlen,
klen,
step=None
):
step = 0 if step is None else step step = 0 if step is None else step
context_position = torch.arange(step, step + qlen, dtype=torch.long, context_position = torch.arange(
device=self.relative_attention_bias.weight.device)[:, None] step,
memory_position = torch.arange(klen, dtype=torch.long, step + qlen,
device=self.relative_attention_bias.weight.device)[None, :] dtype=torch.long,
device=self.relative_attention_bias.weight.device,
)[:, None]
memory_position = torch.arange(
klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen) relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket( rp_bucket = self._relative_position_bucket(
@ -67,16 +65,18 @@ class RelativePositionBias(nn.Module):
num_buckets=self.num_buckets, num_buckets=self.num_buckets,
) )
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) values = self.relative_attention_bias(
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) rp_bucket
) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(
0
) # shape (1, num_heads, qlen, klen)
return values return values
def forward( def forward(self, batch_size, qlen, klen, step=None):
self,
batch_size,
qlen,
klen,
step=None
):
# shape (batch * num_heads, qlen, klen) # shape (batch * num_heads, qlen, klen)
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen) return (
self.compute_bias(qlen, klen, step)
.repeat(batch_size, 1, 1, 1)
.view(-1, qlen, klen)
)
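A minimal sketch of the large-distance branch of the bucketing above, with assumed values (num_buckets and max_exact here stand for the already-adjusted quantities in the T5-style scheme this bias follows):

import math
import torch

num_buckets, max_distance = 16, 128  # assumed toy settings
max_exact = num_buckets // 2

n = torch.tensor([1, 4, 8, 16, 32, 64, 127])  # absolute relative distances
is_small = n < max_exact
val_if_large = max_exact + (
    torch.log(n.float() / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
bucket = torch.where(is_small, n, val_if_large)
print(bucket)  # tensor([ 1,  4,  8, 10, 12, 14, 15])
# Small offsets keep their own bucket; larger ones share log-spaced buckets.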

View File

@ -18,9 +18,9 @@ import torch.distributed as dist
from torch import Tensor from torch import Tensor
from torch.nn import Module, ModuleList from torch.nn import Module, ModuleList
try: try:
from fairseq.modules.moe import MOELayer from fairseq.modules.moe import MOELayer
has_fairseq = True has_fairseq = True
Base = MOELayer Base = MOELayer
except ModuleNotFoundError: except ModuleNotFoundError:
@ -81,8 +81,10 @@ def get_moe_group(moe_expert_count):
else: else:
assert world_size % moe_expert_count == 0 assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count ranks_per_group = world_size // moe_expert_count
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] moe_groups = [
for i in range(moe_expert_count)] [i + j * moe_expert_count for j in range(ranks_per_group)]
for i in range(moe_expert_count)
]
get_moe_group._moe_group_idx = moe_groups get_moe_group._moe_group_idx = moe_groups
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups] get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
@ -105,11 +107,15 @@ def get_all2all_group(moe_expert_count):
else: else:
assert world_size % moe_expert_count == 0 assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count ranks_per_group = world_size // moe_expert_count
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)] all2all_groups = [
for i in range(ranks_per_group)] [i * moe_expert_count + j for j in range(moe_expert_count)]
for i in range(ranks_per_group)
]
get_all2all_group._all2all_group_idx = all2all_groups get_all2all_group._all2all_group_idx = all2all_groups
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups] get_all2all_group._all2all_groups = [
dist.new_group(g) for g in all2all_groups
]
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx) my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
return get_all2all_group._all2all_groups[my_group_idx] return get_all2all_group._all2all_groups[my_group_idx]
@ -133,12 +139,7 @@ class MOELayer(Base):
expert network expert network
""" """
def __init__( def __init__(self, gate, experts, args):
self,
gate,
experts,
args
):
if has_fairseq: if has_fairseq:
super(Base, self).__init__() super(Base, self).__init__()
else: else:
@ -163,9 +164,13 @@ class MOELayer(Base):
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor: def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported" assert len(input) == 1, "only single input Tensor supported"
input = input[0] input = input[0]
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel" assert (
len(input.shape) == 3
), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
if input_padding_mask is not None: if input_padding_mask is not None:
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken" assert (
len(input_padding_mask.shape) == 2
), "input Tensor must have dimensions: (s)equence, (t)oken"
assert input_padding_mask.shape[0] == input.shape[0] assert input_padding_mask.shape[0] == input.shape[0]
assert input_padding_mask.shape[1] == input.shape[1] assert input_padding_mask.shape[1] == input.shape[1]
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be a multiple of the number of local experts" # assert input.shape[0] % len(self.experts) == 0, "num tokens must be a multiple of the number of local experts"
@ -174,81 +179,120 @@ class MOELayer(Base):
d_model = input.shape[2] d_model = input.shape[2]
# Pad to expected batch size # Pad to expected batch size
input_shape = list(input.shape) input_shape = list(input.shape)
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0) expected_bsz = (
getattr(self.args, "batch_size", 0)
if self.training
else getattr(self.args, "batch_size_valid", 0)
)
# This indicates that --batch-size or --max-sentences is not specified # This indicates that --batch-size or --max-sentences is not specified
if expected_bsz is None: if expected_bsz is None:
expected_bsz = 0 expected_bsz = 0
# Note: Padding is not necessary at generation time at present # Note: Padding is not necessary at generation time at present
# because all DDP workers process the same batch. Also, batch size at generation time # because all DDP workers process the same batch. Also, batch size at generation time
# can be different from that present in the checkpoint state # can be different from that present in the checkpoint state
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz: if (
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})") not self.in_generation
and expected_bsz != 0
and input_shape[0] != expected_bsz
):
logger.warning(
f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
)
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}" assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
padded_input = torch.zeros( padded_input = torch.zeros(
(expected_bsz, input_shape[1], input_shape[2]), (expected_bsz, input_shape[1], input_shape[2]),
dtype=input.dtype, layout=input.layout, device=input.device) dtype=input.dtype,
padded_input[:input_shape[0], :, :] = input layout=input.layout,
device=input.device,
)
padded_input[: input_shape[0], :, :] = input
input = padded_input input = padded_input
padded_input_padding_mask = torch.ones( padded_input_padding_mask = torch.ones(
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device (
expected_bsz,
input_shape[1],
),
dtype=torch.bool,
device=input.device,
) )
if input_padding_mask is not None: if input_padding_mask is not None:
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
else: else:
padded_input_padding_mask[:input_shape[0], :] = False padded_input_padding_mask[: input_shape[0], :] = False
input_padding_mask = padded_input_padding_mask input_padding_mask = padded_input_padding_mask
# Reshape into S tokens by dropping sequence dimension. # Reshape into S tokens by dropping sequence dimension.
reshaped_input = input.reshape(-1, d_model) reshaped_input = input.reshape(-1, d_model)
reshaped_input_shape = reshaped_input.shape reshaped_input_shape = reshaped_input.shape
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None reshaped_input_padding_mask = (
input_padding_mask.reshape(-1) if input_padding_mask is not None else None
)
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences # Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
# Pro of --max-tokens: more flexible for MT variable sequence lengths # Pro of --max-tokens: more flexible for MT variable sequence lengths
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM # Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
if expected_bsz == 0: if expected_bsz == 0:
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device) expected_dim = reshaped_input_shape[0] * torch.ones(
(1,), dtype=torch.long, device=input.device
)
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX) dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
expected_dim = int(expected_dim.item()) expected_dim = int(expected_dim.item())
padded_input = torch.zeros( padded_input = torch.zeros(
(expected_dim, reshaped_input_shape[1]), (expected_dim, reshaped_input_shape[1]),
dtype=input.dtype, layout=input.layout, device=input.device) dtype=input.dtype,
padded_input[:reshaped_input_shape[0], :] = reshaped_input layout=input.layout,
device=input.device,
)
padded_input[: reshaped_input_shape[0], :] = reshaped_input
reshaped_input = padded_input reshaped_input = padded_input
padded_input_padding_mask = torch.ones( padded_input_padding_mask = torch.ones(
(expected_dim,), dtype=torch.bool, device=padded_input.device (expected_dim,), dtype=torch.bool, device=padded_input.device
) )
if reshaped_input_padding_mask is not None: if reshaped_input_padding_mask is not None:
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask padded_input_padding_mask[
: reshaped_input_shape[0]
] = reshaped_input_padding_mask
else: else:
padded_input_padding_mask[:reshaped_input_shape[0]] = False padded_input_padding_mask[: reshaped_input_shape[0]] = False
reshaped_input_padding_mask = padded_input_padding_mask reshaped_input_padding_mask = padded_input_padding_mask
if has_tutel: if has_tutel:
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask) l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
reshaped_input, reshaped_input_padding_mask
)
S, M = reshaped_input.size(0), reshaped_input.size(1) S, M = reshaped_input.size(0), reshaped_input.size(1)
if not hasattr(self, '_tutel_dispatcher'): if not hasattr(self, "_tutel_dispatcher"):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype) self._tutel_dispatcher = tutel_moe.fast_dispatcher(
E, C, M, dispatch_dtype=reshaped_input.dtype
)
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C) self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input) dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else: else:
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask) l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
reshaped_input, reshaped_input_padding_mask
)
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S dispatch_mask = dispatch_mask.to(input.dtype).permute(
1, 2, 0
) # S,E,C -> E,C,S
E, C, S = dispatch_mask.size() E, C, S = dispatch_mask.size()
M = reshaped_input.size(1) M = reshaped_input.size(1)
assert reshaped_input.size() == (S, M) assert reshaped_input.size() == (S, M)
# einsum("sec,sm->ecm") # einsum("sec,sm->ecm")
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M dispatched_input = torch.mm(
dispatch_mask.view(E * C, S), reshaped_input
) # -> (E*C),M
if self.all2all_size > 1: if self.all2all_size > 1:
dispatched_input = self.all_to_all_wrapper(dispatched_input) dispatched_input = self.all_to_all_wrapper(dispatched_input)
# Re-shape after all-to-all: ecm -> gecm # Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model) dispatched_input = dispatched_input.reshape(
self.all2all_size, self.num_local_experts, -1, d_model
)
chunks = dispatched_input.chunk(self.num_local_experts, dim=1) chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
expert_outputs = [] expert_outputs = []
for chunk, expert in zip(chunks, self.experts): for chunk, expert in zip(chunks, self.experts):
@ -259,18 +303,24 @@ class MOELayer(Base):
expert_output = self.all_to_all_wrapper(expert_output) expert_output = self.all_to_all_wrapper(expert_output)
# Re-shape back: gecm -> ecm # Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model) expert_output = expert_output.reshape(
self.all2all_size * self.num_local_experts, -1, d_model
)
if has_tutel: if has_tutel:
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M)) combined_output = self._tutel_dispatcher.decode(
expert_output.view(E * C, M)
)
else: else:
# einsum("sec,ecm->sm") # einsum("sec,ecm->sm")
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M)) combined_output = combine_weights.view(S, E * C).mm(
expert_output.view(E * C, M)
)
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences # Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
combined_output = combined_output[:reshaped_input_shape[0], :] combined_output = combined_output[: reshaped_input_shape[0], :]
combined_output = combined_output.reshape(input.shape) combined_output = combined_output.reshape(input.shape)
combined_output = combined_output[:input_shape[0], :, :] combined_output = combined_output[: input_shape[0], :, :]
self.record_all_to_all_stats() self.record_all_to_all_stats()
@ -280,7 +330,7 @@ class MOELayer(Base):
self.in_generation = True self.in_generation = True
def all_to_all_wrapper(self, input: Tensor): def all_to_all_wrapper(self, input: Tensor):
dummy_a2a = getattr(self.args, 'dummy_a2a', False) dummy_a2a = getattr(self.args, "dummy_a2a", False)
if dummy_a2a: if dummy_a2a:
input = input.contiguous() input = input.contiguous()
output = input.detach().clone() output = input.detach().clone()
@ -294,13 +344,13 @@ class MOELayer(Base):
output = _AllToAll.apply(self.all2all_group, input) output = _AllToAll.apply(self.all2all_group, input)
cuda_end.record() cuda_end.record()
cpu_end = time.time() * 1000 cpu_end = time.time() * 1000
self.a2a_cpu_time_ms += (cpu_end - cpu_start) self.a2a_cpu_time_ms += cpu_end - cpu_start
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end)) self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
return output return output
def record_all_to_all_stats(self): def record_all_to_all_stats(self):
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize() # controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False) record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
if record_a2a_perf_stats: if record_a2a_perf_stats:
torch.cuda.synchronize() torch.cuda.synchronize()
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
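A minimal sketch, with assumed toy sizes and a hand-built one-hot assignment, of the dispatch and combine matrix products used in the non-tutel path above:

import torch

S, E, C, M = 6, 2, 3, 4  # tokens, experts, capacity, model dim (assumed toy sizes)
reshaped_input = torch.randn(S, M)

# Pretend the gate sent token s to expert s % E, capacity slot s // E.
dispatch_mask = torch.zeros(S, E, C)
for s in range(S):
    dispatch_mask[s, s % E, s // E] = 1.0
combine_weights = dispatch_mask.clone()  # top-1 style: combine equals dispatch here

# einsum("sec,sm->ecm") written as a single mm
dispatched_input = torch.mm(dispatch_mask.permute(1, 2, 0).reshape(E * C, S), reshaped_input)

expert_output = dispatched_input  # identity "experts" for the sketch
# einsum("sec,ecm->sm") written as a single mm
combined_output = combine_weights.view(S, E * C).mm(expert_output.view(E * C, M))
print(torch.allclose(combined_output, reshaped_input))  # True: tokens round-trip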

View File

@ -13,14 +13,14 @@
# NOTE: This is a mirror of the code in # NOTE: This is a mirror of the code in
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe # https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
from typing import Callable, Dict, Tuple, Optional
import math import math
import torch from typing import Callable, Dict, Optional, Tuple
from torch import Tensor
import torch.nn.functional as F
from .moe_layer import has_tutel, fused_cumsum_sub_one import torch
import torch.nn.functional as F
from torch import Tensor
from .moe_layer import fused_cumsum_sub_one, has_tutel
# use a fixed temperature to compute balance loss # use a fixed temperature to compute balance loss
TEMPERATURE_FOR_L_UAX = 0.07 TEMPERATURE_FOR_L_UAX = 0.07
@ -65,13 +65,22 @@ def top1gating(
indices1_s = torch.argmax(gates, dim=1) indices1_s = torch.argmax(gates, dim=1)
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True) mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
if input_mask is not None and input_mask.any(): if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask nonpadding = ~input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
# for logging (percent of tokens routed to each expert) # for logging (percent of tokens routed to each expert)
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens expert1_hist = (
100
* torch.histc(
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
)
/ num_tokens
)
metadata["unused_expert1_count"] = (expert1_hist == 0).sum() metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny expert1_hist = (
torch.sort(expert1_hist, dim=0, descending=True).values
+ torch.finfo(torch.float32).tiny
)
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1) sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum() metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
@ -91,7 +100,21 @@ def top1gating(
if has_tutel: if has_tutel:
locations1_s = torch.sum(locations1 * mask1, dim=1) locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ] return (
l_aux,
metadata,
capacity,
num_experts,
[
indices1_s,
],
[
locations1_s,
],
[
gates1_s,
],
)
# Remove locations outside capacity from mask # Remove locations outside capacity from mask
mask1 = mask1 * torch.lt(locations1, capacity) mask1 = mask1 * torch.lt(locations1, capacity)
@ -104,7 +127,8 @@ def top1gating(
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True) locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
combine1_sec = torch.bmm( combine1_sec = torch.bmm(
# einsum("se,sc->sec") # einsum("se,sc->sec")
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1) gates1.unsqueeze(-1),
locations1_sc.to(gates1.dtype).unsqueeze(1),
) )
dispatch_mask = combine1_sec.bool() dispatch_mask = combine1_sec.bool()
if use_fp32: if use_fp32:
@ -218,10 +242,10 @@ def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) ->
if unsqueeze_indices: if unsqueeze_indices:
indices = indices.unsqueeze(-1) indices = indices.unsqueeze(-1)
assert indices.shape[-1] == 1, "last dimension of indices must have size 1" assert indices.shape[-1] == 1, "last dimension of indices must have size 1"
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype) output = torch.zeros(
output.scatter_( indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
len(output.shape) - 1, indices, 1
) )
output.scatter_(len(output.shape) - 1, indices, 1)
return output return output
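A minimal sketch of the scatter_-based one_hot helper above, mapping assumed per-token expert indices to one-hot rows:

import torch

indices = torch.tensor([2, 0, 1])  # assumed expert index per token
num_classes = 4

indices = indices.unsqueeze(-1)  # the unsqueeze_indices=True path
output = torch.zeros(
    indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
)
output.scatter_(len(output.shape) - 1, indices, 1)
print(output)
# tensor([[0, 0, 1, 0],
#         [1, 0, 0, 0],
#         [0, 1, 0, 0]])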
@ -235,7 +259,7 @@ def top2gating(
logits: torch.Tensor, logits: torch.Tensor,
input_mask: Optional[torch.Tensor] = None, input_mask: Optional[torch.Tensor] = None,
use_fp32=False, use_fp32=False,
second_expert_policy='sampling', second_expert_policy="sampling",
normalize_gate_prob_before_dropping=False, normalize_gate_prob_before_dropping=False,
eval_mode=False, eval_mode=False,
moe_eval_capacity_token_fraction=0.25, moe_eval_capacity_token_fraction=0.25,
@ -260,7 +284,7 @@ def top2gating(
# Create a mask for the 1st expert per token # Create a mask for the 1st expert per token
indices1_s = torch.argmax(gates, dim=1, keepdim=True) indices1_s = torch.argmax(gates, dim=1, keepdim=True)
mask1 = one_hot(indices1_s, num_experts) mask1 = one_hot(indices1_s, num_experts)
if second_expert_policy == 'sampling': if second_expert_policy == "sampling":
# Create a mask for the 2nd expert per token using the Gumbel-max trick # Create a mask for the 2nd expert per token using the Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
@ -281,13 +305,13 @@ def top2gating(
gates1_s = gates1_s / denom_s gates1_s = gates1_s / denom_s
gates2_s = gates2_s / denom_s gates2_s = gates2_s / denom_s
if second_expert_policy == 'random': if second_expert_policy == "random":
sampled = (2 * gates2_s) > torch.rand_like(gates2_s) sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0) mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
# Compute locations in capacity buffer # Compute locations in capacity buffer
if input_mask is not None and input_mask.any(): if input_mask is not None and input_mask.any():
nonpadding = ~ input_mask nonpadding = ~input_mask
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype) mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype) mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
@@ -296,15 +320,22 @@ def top2gating(
         importance_scores = -1 * gates.max(dim=1)[0]
         sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
         sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
-        importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
+        importance_sorted_locations1 = sorted_cumsum1[
+            importance_scores.argsort(dim=0).argsort(dim=0)
+        ]

         sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
         sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
-        importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
+        importance_sorted_locations2 = sorted_cumsum2[
+            importance_scores.argsort(dim=0).argsort(dim=0)
+        ]

         importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)

-        locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
+        locations1, locations2 = (
+            importance_sorted_locations1,
+            importance_sorted_locations2,
+        )
     else:
         locations1 = fused_cumsum_sub_one(mask1)
         locations2 = fused_cumsum_sub_one(mask2)
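fused_cumsum_sub_one is defined elsewhere in this file; the sketch below assumes the usual non-tutel behavior (a cumulative sum over the token dimension minus one), which turns each one-hot expert mask into a per-token slot index within that expert's buffer:

    import torch

    def cumsum_sub_one(mask):
        # Assumed stand-in for fused_cumsum_sub_one
        return torch.cumsum(mask, dim=0) - 1

    mask1 = torch.tensor([[1, 0], [0, 1], [1, 0], [1, 0]])  # first-choice expert, one-hot per token
    locations1 = cumsum_sub_one(mask1)
    # Reading off the chosen expert's column per token:
    # token 0 -> slot 0 of expert 0, token 1 -> slot 0 of expert 1,
    # token 2 -> slot 1 of expert 0, token 3 -> slot 2 of expert 0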
@@ -318,8 +349,12 @@ def top2gating(
     l_aux = l_aux * num_experts * num_experts

     # for logging purposes
-    metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
-    metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
+    metadata["overflow_expert1"] = (
+        100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
+    )
+    metadata["overflow_expert2"] = (
+        100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
+    )

     # Remove locations outside capacity from mask
     mask1_, mask2_ = mask1, mask2
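To make the overflow statistic concrete, here is a small example with assumed values (the real capacity is computed earlier in top2gating and is not shown in this hunk): a token overflows when its slot index within its expert's buffer is at or beyond that expert's capacity.

    import torch

    capacity = 2                                   # assumed example capacity per expert
    mask1 = torch.tensor([[1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [1, 0]])
    locations1 = torch.cumsum(mask1, dim=0) - 1    # slot index of each token within its expert
    overflow_expert1 = (
        100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
    )
    # Expert 0 gets 4 tokens but only 2 fit, so 2 of the 6 routed tokens overflow: ~33.3%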
@@ -327,13 +362,31 @@ def top2gating(
     mask2 = mask2 * torch.lt(locations2, capacity)

     # for logging (percent of tokens routed to each expert)
-    expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
+    expert1_hist = (
+        100
+        * torch.histc(
+            (indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
+        )
+        / num_tokens
+    )
     metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
-    expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
+    expert1_hist = (
+        torch.sort(expert1_hist, dim=0, descending=True).values
+        + torch.finfo(torch.float32).tiny
+    )

-    expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
+    expert2_hist = (
+        100
+        * torch.histc(
+            (indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
+        )
+        / num_tokens
+    )
     metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
-    expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
+    expert2_hist = (
+        torch.sort(expert2_hist, dim=0, descending=True).values
+        + torch.finfo(torch.float32).tiny
+    )

     sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
     metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
@@ -355,8 +408,15 @@ def top2gating(
     if has_tutel:
         locations1_s = torch.sum(locations1 * mask1_, dim=1)
         locations2_s = torch.sum(locations2 * mask2_, dim=1)
-        return l_aux, metadata, capacity, num_experts, \
-            [indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
+        return (
+            l_aux,
+            metadata,
+            capacity,
+            num_experts,
+            [indices1_s, indices2_s],
+            [locations1_s, locations2_s],
+            [gates1_s, gates2_s],
+        )

     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)
@@ -369,11 +429,13 @@ def top2gating(
     locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
     combine1_sec = torch.bmm(
         # einsum("se,sc->sec")
-        gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
+        gates1.unsqueeze(-1),
+        locations1_sc.to(gates1.dtype).unsqueeze(1),
     )
     combine2_sec = torch.bmm(
         # einsum("se,sc->sec")
-        gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
+        gates2.unsqueeze(-1),
+        locations2_sc.to(gates2.dtype).unsqueeze(1),
     )
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()
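The bmm calls implement the einsum noted in the inline comments. A quick standalone check with made-up shapes shows the equivalence: the batched outer product of the per-token gate values and the one-hot capacity slots yields the (tokens, experts, capacity) combine tensor.

    import torch

    num_tokens, num_experts, capacity = 3, 4, 2
    gates1 = torch.rand(num_tokens, num_experts)
    locations1_sc = torch.zeros(num_tokens, capacity)
    locations1_sc[torch.arange(num_tokens), torch.tensor([0, 1, 0])] = 1.0  # one-hot slot per token

    via_bmm = torch.bmm(gates1.unsqueeze(-1), locations1_sc.unsqueeze(1))
    via_einsum = torch.einsum("se,sc->sec", gates1, locations1_sc)
    assert torch.allclose(via_bmm, via_einsum)     # both are (num_tokens, num_experts, capacity)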
@@ -406,7 +468,7 @@ class Top2Gate(torch.nn.Module):
         model_dim: int,
         num_experts: int,
         use_fp32=False,
-        second_expert_policy='sampling',
+        second_expert_policy="sampling",
         normalize_gate_prob_before_dropping=False,
         moe_eval_capacity_token_fraction=0.25,
         batch_prioritized_routing=False,
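For orientation, a hedged sketch of constructing and calling this gate. The constructor arguments follow the signature above; the forward call and the four return values are an assumption based on the non-tutel return path of top2gating earlier in this file, not something shown in this hunk.

    import torch

    gate = Top2Gate(
        model_dim=512,
        num_experts=8,
        use_fp32=True,
        second_expert_policy="sampling",
    )
    tokens = torch.randn(64, 512)                  # (num_tokens, model_dim), flattened batch
    # Assumed calling convention: auxiliary load-balancing loss, dense combine weights,
    # boolean dispatch mask, and the logging metadata dict.
    l_aux, combine_weights, dispatch_mask, metadata = gate(tokens)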

View File

@@ -3,37 +3,35 @@

 import torch
 import torch.nn as nn
 from torchscale.architecture.encoder import Encoder
-from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
+from torchscale.component.embedding import (
+    PositionalEmbedding,
+    TextEmbedding,
+    VisionEmbedding,
+)
 from torchscale.component.multiway_network import MultiwayWrapper


 class BEiT3(nn.Module):
     def __init__(self, args, **kwargs):
         super().__init__()
         self.args = args
         assert args.multiway
         assert args.vocab_size > 0
         assert not args.share_encoder_input_output_embed
-        self.text_embed = TextEmbedding(
-            args.vocab_size,
-            args.encoder_embed_dim
-        )
+        self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
         self.vision_embed = VisionEmbedding(
             args.img_size,
             args.patch_size,
             args.in_chans,
             args.encoder_embed_dim,
             contain_mask_token=True,
-            prepend_cls_token=True
+            prepend_cls_token=True,
         )
         embed_positions = MultiwayWrapper(
             args,
-            PositionalEmbedding(
-                args.max_source_positions,
-                args.encoder_embed_dim
-            ),
+            PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
             dim=1,
         )
         self.encoder = Encoder(
@@ -71,7 +69,7 @@ class BEiT3(nn.Module):
                 encoder_padding_mask = torch.cat(
                     [
                         torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
-                        text_padding_position
+                        text_padding_position,
                     ],
                     dim=1,
                 )
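A small sketch of what this concatenation produces, with made-up sizes: image patches are never padding, so their mask entries are all False, and the text padding mask is appended after them so the combined mask lines up with the vision and text embeddings that are concatenated earlier in forward (not shown in this hunk).

    import torch

    batch, num_patches, num_text_tokens, dim = 2, 5, 4, 8
    x1 = torch.randn(batch, num_patches, dim)      # vision embeddings (example shape)
    text_padding_position = torch.tensor([[0, 0, 1, 1],
                                          [0, 0, 0, 1]]).bool()  # True where text is padding

    encoder_padding_mask = torch.cat(
        [
            torch.zeros(x1.shape[:-1]).to(x1.device).bool(),  # vision positions: never padded
            text_padding_position,
        ],
        dim=1,
    )
    # encoder_padding_mask.shape == (batch, num_patches + num_text_tokens) == (2, 9)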