Code reformatting
parent 1354614d44
commit 7eca1a531c
@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks

from fairseq_cli.generate import cli_main

if __name__ == "__main__":
@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks

from fairseq_cli.interactive import cli_main

if __name__ == "__main__":
@@ -2,24 +2,24 @@
# Licensed under The MIT License [see LICENSE for details]

import logging
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
)
from fairseq.modules import PositionalEmbedding
from fairseq.models.squad import SQuADHead
from omegaconf import II
from .machine_translation import MTEncoder as Encoder
from torchscale.architecture.config import EncoderConfig
from apex.normalization import FusedLayerNorm as LayerNorm
from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.models.squad import SQuADHead
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from omegaconf import II

from torchscale.architecture.config import EncoderConfig

from .machine_translation import MTEncoder as Encoder

DEFAULT_MAX_SOURCE_POSITIONS = 1024

@@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
},
)
max_source_positions: int = field(
default=1024, metadata={"help": "max source positions"}
@@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
default="relu", metadata={"help": "activation function to use for pooler layer"}
)
pooler_dropout: float = field(
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
default=0.0,
metadata={"help": "dropout probability in the masked_lm pooler layers"},
)
# options from other parts of the config
# add_bos_token: bool = II("task.add_bos_token")
# tokens_per_sample: int = II("task.tokens_per_sample")
tpu: bool = II("common.tpu")
rel_pos_buckets: int = field(
default=0, metadata={"help": ""}
)
max_rel_pos: int = field(
default=0, metadata={"help": ""}
)
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
max_rel_pos: int = field(default=0, metadata={"help": ""})
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
},
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
@@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):

@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):

def __init__(self, args, encoder):
super().__init__()
self.args = args
@@ -240,7 +227,11 @@ class BertModel(BaseFairseqModel):
)

lm_head = cls.build_lm_head(
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
args,
args.encoder_embed_dim,
len(task.dictionary),
args.activation_fn,
weight=embed_tokens.weight,
)

config = EncoderConfig()
@@ -269,7 +260,9 @@ class BertModel(BaseFairseqModel):
def output_layer(self, features, masked_tokens=None):
return self.encoder.output_projection(features, masked_tokens=masked_tokens)

def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
@@ -277,7 +270,7 @@ class BertModel(BaseFairseqModel):
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
'and inner_dim {} (prev: {})'.format(
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
@@ -295,42 +288,51 @@ class BertModel(BaseFairseqModel):
)

def upgrade_state_dict_named(self, state_dict, name):
prefix = name + '.' if name != '' else ''
prefix = name + "." if name != "" else ""

# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)

# Handle new classification heads present in the state dict.
current_head_names = (
[] if not hasattr(self, 'classification_heads')
[]
if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
if not k.startswith(prefix + "classification_heads."):
continue

head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]  # noqa: E203
num_classes = state_dict[
prefix + "classification_heads." + head_name + ".out_proj.weight"
].size(0)
inner_dim = state_dict[
prefix + "classification_heads." + head_name + ".dense.weight"
].size(0)

if getattr(self.args, 'load_checkpoint_heads', False):
if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
'deleting classification head ({}) from checkpoint '
'not present in current model: {}'.format(head_name, k)
"deleting classification head ({}) from checkpoint "
"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes != self.classification_heads[head_name].out_proj.out_features
or inner_dim != self.classification_heads[head_name].dense.out_features
num_classes
!= self.classification_heads[head_name].out_proj.out_features
or inner_dim
!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
'deleting classification head ({}) from checkpoint '
'with different dimensions than current model: {}'.format(head_name, k)
"deleting classification head ({}) from checkpoint "
"with different dimensions than current model: {}".format(
head_name, k
)
)
keys_to_delete.append(k)
for k in keys_to_delete:
@@ -338,12 +340,12 @@ class BertModel(BaseFairseqModel):

# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + 'classification_heads.' + k not in state_dict:
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
if prefix + "classification_heads." + k not in state_dict:
logger.info("Overwriting " + prefix + "classification_heads." + k)
state_dict[prefix + "classification_heads." + k] = v

def forward(
self,
@@ -354,7 +356,9 @@ class BertModel(BaseFairseqModel):
masked_tokens=None,
**kwargs
):
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
encoder_out = self.encoder(
src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
)
x, extra = encoder_out["encoder_out"], encoder_out
x = x.transpose(0, 1)

@@ -455,7 +459,7 @@ def base_unilm_architecture(args):
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)

# Model training is not stable without this
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)

args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
@@ -9,10 +9,9 @@
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch

from fairseq import utils
from fairseq import distributed_utils
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
@@ -20,14 +19,13 @@ from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from fairseq.modules import PositionalEmbedding
from torchscale.architecture.decoder import Decoder
from torchscale.architecture.config import DecoderConfig
from omegaconf import II

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)

@@ -104,49 +102,34 @@ class LanguageConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
},
)
moe_freq: int = field(
default=0,
metadata={
"help": "Frequency at which we insert MoE Transformer layers"
},
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0,
metadata={
"help": "Number of experts in each MoE Layer"
}
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={
"help": "Use FP32 computations in MoE top2 gating function"
}
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default='sampling',
metadata={
"help": "policy for second expert, options: all/sampling/random"
}
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
}
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None,
metadata={
"help": "MoE expert FFN dimension"
}
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False,
metadata={
"help": "Use top1 gate instead of top2"
}
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
@@ -155,23 +138,29 @@ class LanguageConfig(FairseqDataclass):
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
}
},
)
moe_normalize_expert_grad: Optional[str] = field(
default='world_size',
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
}
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False, metadata={"help": "records all to all perf stats during distributed training"}
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False, metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"}
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
@@ -205,7 +194,6 @@ class LanguageConfig(FairseqDataclass):

@register_model("lm", dataclass=LanguageConfig)
class LanguageModel(FairseqLanguageModel):

def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@@ -245,19 +233,17 @@ class LanguageModel(FairseqLanguageModel):
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)

if (
getattr(args, 'moe_freq', 0) > 0
and (
getattr(args, 'fp16', False)
and not getattr(args, 'memory_efficient_fp16', False)
and getattr(args, 'ddp_backend', None) != "fully_sharded"
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert args.fp16_no_flatten_grads, \
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"

args.ddp_rank = distributed_utils.get_data_parallel_rank()

@@ -281,7 +267,6 @@ class LanguageModel(FairseqLanguageModel):

class LMDecoder(Decoder, FairseqIncrementalDecoder):

def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
@@ -6,12 +6,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List, Optional, Tuple

import torch
from fairseq import utils
from fairseq import distributed_utils, utils
from fairseq.distributed import utils as fsdp_wrap
from fairseq import distributed_utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
@@ -20,12 +20,13 @@ from fairseq.models import (
)
from fairseq.models.transformer import Embedding
from fairseq.modules import PositionalEmbedding
from torch import Tensor

from torchscale.architecture.config import DecoderConfig, EncoderConfig
from torchscale.architecture.encoder import Encoder
from torchscale.architecture.config import EncoderConfig, DecoderConfig

from .language_modeling import LMDecoder as MTDecoder

from torch import Tensor
import logging
logger = logging.getLogger(__name__)

DEFAULT_MAX_SOURCE_POSITIONS = 1024
@@ -35,7 +36,6 @@ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)

@register_model("mt")
class TranslationModel(FairseqEncoderDecoderModel):

def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
@@ -269,7 +269,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
args.decoder_embed_dim, len(tgt_dict), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)

encoder = cls.build_encoder(
@@ -320,7 +320,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
)

@classmethod
def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
def build_decoder(
cls, args, embed_tokens, embed_positions, output_projection, dictionary
):
config = DecoderConfig()
config.override(args)

@@ -342,10 +344,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
features_only: bool = False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
@@ -365,15 +364,20 @@ class TranslationModel(FairseqEncoderDecoderModel):

class MTEncoder(Encoder, FairseqEncoder):

def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)
return super().forward(
src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
)

def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
0, new_order
)
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
0, new_order
)

encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
@@ -3,6 +3,7 @@

import torch
from infinibatch.iterators import CheckpointableIterator

from . import utils

@@ -25,7 +26,6 @@ class BaseBatchGen(CheckpointableIterator):
raise NotImplementedError()

def _move_to_tensor(self, batch):

def to_tensor(x):
return torch.tensor(x)
@@ -1,32 +1,32 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import os
import numpy as np
import itertools
import copy
import itertools
import os

import numpy as np
from infinibatch import iterators

from .basic_loader import BaseBatchGen
from .utils import NativeCheckpointableIterator, WeightIterator


class MLMLoader(BaseBatchGen):

def __init__(
self,
args,
dataset,
dictionary,
tokenizer,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
self,
args,
dataset,
dictionary,
tokenizer,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
):
super().__init__()
self.args = args
@@ -62,9 +62,7 @@ class MLMLoader(BaseBatchGen):
log_empty_buffer_warning=True and self.shard_id == 0,
)

prefetch_batches = iterators.MapIterator(
prefetch_batches, self._move_to_tensor
)
prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)

self._iter = prefetch_batches

@@ -73,25 +71,25 @@ class MLMLoader(BaseBatchGen):
weights = []

for data in self.data:
multilingual_iters.append(
self._tokenize(data)
)
if 'weight' in data:
weights.append(float(data['weight']))
multilingual_iters.append(self._tokenize(data))
if "weight" in data:
weights.append(float(data["weight"]))
else:
weights.append(int(data['count']))
weights.append(int(data["count"]))

if len(multilingual_iters) == 1:
return multilingual_iters[0]

sampling_iterator = WeightIterator(weights)
control_iterator = NativeCheckpointableIterator(sampling_iterator)
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
tokenized_lines = iterators.MultiplexIterator(
control_iterator, multilingual_iters
)

return tokenized_lines

def _tokenize(self, data):
'''
"""
data:
{
'source': list[Path],
@@ -100,33 +98,35 @@ class MLMLoader(BaseBatchGen):
'weight': float,
'name': str,
}
'''
"""
dataset = list(
zip(
data['source'],
itertools.repeat(data['source_lang']),
)
zip(
data["source"],
itertools.repeat(data["source_lang"]),
)
)

if self.shuffle:
chunk_files = \
iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
chunk_files = iterators.InfinitePermutationSourceIterator(
dataset,
seed=self.seed,
shuffle=self.shuffle,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
else:
chunk_files = \
iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)
chunk_files = iterators.ChunkedSourceIterator(
dataset,
num_instances=self.num_shards,
instance_rank=self.shard_id,
)

tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
tokenized_lines = iterators.SelectManyIterator(
chunk_files, lambda files: self._read_from_files(*files)
)
tokenized_lines = iterators.SamplingRandomMapIterator(
tokenized_lines, self._prepare, self.seed
)

return tokenized_lines

@@ -134,22 +134,29 @@ class MLMLoader(BaseBatchGen):

if self.max_sentences is not None:
if self.batch_read_ahead > 0:
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
lines = iterators.BlockwiseShuffleIterator(
lines, self.batch_read_ahead, self.seed
)
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
else:

def dynamic_batch_size(sample):
lengths = [len(x) for x in sample]
batch_size = self.max_tokens // max(lengths)
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
batch_size = (
batch_size
// self.required_batch_size_multiple
* self.required_batch_size_multiple
)
return max(1, batch_size)

batches = iterators.BucketedReadaheadBatchIterator(
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
lines,
read_ahead=self.batch_read_ahead,
key=(lambda x: max(len(x[0]), len(x[1]))) if self.shuffle else None,
batch_size=dynamic_batch_size,
shuffle=self.shuffle,
seed=self.seed,
)

def collate(batch):
@@ -160,38 +167,56 @@ class MLMLoader(BaseBatchGen):
s2s_source_max_length = max([len(x[2]) for x in batch])
s2s_target_max_length = max([len(x[3]) for x in batch])

mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
fill_value=self.dictionary.pad())
mlm_source_ids = np.full(
shape=(batch_size, mlm_source_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
mlm_target_ids = np.full(
shape=(batch_size, mlm_target_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_source_ids = np.full(
shape=(batch_size, s2s_source_max_length),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_target_ids = np.full(
shape=(batch_size, s2s_target_max_length - 1),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)
s2s_prev_input_ids = np.full(
shape=(batch_size, s2s_target_max_length - 1),
dtype=np.int32,
fill_value=self.dictionary.pad(),
)

for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
for i, (
mlm_input_ids,
mlm_label_ids,
s2s_input_ids,
s2s_label_ids,
) in enumerate(batch):
mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]

ret_batch = {
'net_input': {
'src_tokens': mlm_source_ids.astype(np.int64),
"net_input": {
"src_tokens": mlm_source_ids.astype(np.int64),
},
'target': mlm_target_ids.astype(np.int64),
'nsentences': batch_size,
'ntokens': sum([len(x[0]) for x in batch]),
"target": mlm_target_ids.astype(np.int64),
"nsentences": batch_size,
"ntokens": sum([len(x[0]) for x in batch]),
}

return ret_batch

padded_batches = iterators.MapIterator(
batches, collate
)
padded_batches = iterators.MapIterator(batches, collate)

return padded_batches

@@ -221,7 +246,6 @@ class MLMLoader(BaseBatchGen):
return nonmasked_tokens, masked_tokens

def _span_corruption(self, _random, doc):

def mask_tokens(i):
return f"<mask_{i}>"

@@ -237,7 +261,9 @@ class MLMLoader(BaseBatchGen):
else:
possible_split_positions = list(range(1, noise_tokens_num))
_random.shuffle(possible_split_positions)
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
noise_split_positions = sorted(
possible_split_positions[: noise_spans_num - 1]
)
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]

possible_insert_positions = list(range(nonnoise_tokens_num))
@@ -248,7 +274,7 @@ class MLMLoader(BaseBatchGen):
last_end = 0
for i in range(noise_spans_num):
start_pos = noise_insert_positions[i] + noise_split_positions[i]
end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
mask_id = self.dictionary.indices[mask_tokens(i)]

if getattr(self.args, "remove_target_sentinel", False):
@@ -273,23 +299,25 @@ class MLMLoader(BaseBatchGen):
file_path = os.path.join(self.data_dir, source_file)

if not os.path.exists(file_path):
print('| file {} not exists'.format(file_path), flush=True)
print("| file {} not exists".format(file_path), flush=True)
return iter([])  # skip bad file

with open(file_path, 'r', encoding='utf8') as f:
lines = f.read().strip().split('\n')
with open(file_path, "r", encoding="utf8") as f:
lines = f.read().strip().split("\n")

doc = [self.dictionary.bos()]
for line in lines:
if line == "":
if self.sample_break_mode == 'complete_doc':
if self.sample_break_mode == "complete_doc":
# data.append(doc)
yield doc
doc = [self.dictionary.bos()]
continue

tokenized_line = self.tokenizer.EncodeAsPieces(line)
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
tokenized_id = [
self.dictionary.index(token) for token in tokenized_line
] + [self.dictionary.eos_index]

if len(tokenized_id) > self.tokens_per_sample:
continue
@@ -1,10 +1,11 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import numpy as np
import collections
from random import Random
from typing import Dict, Iterable, Optional
import collections

import numpy as np
from infinibatch import iterators

@@ -17,7 +18,9 @@ def apply_to_sample(f, sample):
return f(x)
elif isinstance(x, collections.OrderedDict):
# OrderedDict has attributes that needs to be preserved
od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
od = collections.OrderedDict(
(key, _apply(value)) for key, value in x.items()
)
od.__dict__ = x.__dict__
return od
elif isinstance(x, dict):
@@ -40,14 +43,15 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
self.setstate(None)

def getstate(self) -> Dict:
return {'num_items_yielded': self._num_items_yielded}
return {"num_items_yielded": self._num_items_yielded}

def setstate(self, checkpoint: Optional[Dict]):
self._iterator = iter(self._input_iterable)
self._num_items_yielded = iterators._advance_iterator(
self._iterator,
checkpoint['num_items_yielded']
) if checkpoint is not None else 0
self._num_items_yielded = (
iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
if checkpoint is not None
else 0
)

def __next__(self):
item = next(self._iterator)
@@ -73,7 +77,9 @@ class WeightIterator(object):

def setstate(self, checkpoint):
self._random_state = checkpoint["random_state"] if checkpoint else None
self._random = None  # this will trigger the lazy initialization in self.__next__
self._random = (
None  # this will trigger the lazy initialization in self.__next__
)

def __next__(self):
if self._random is None:
@@ -1,23 +1,25 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import json
import logging
import os
from argparse import Namespace

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
import os
from argparse import Namespace
import json
from omegaconf import MISSING, II

import sentencepiece as spm
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING

from .data.mlm_loader import MLMLoader
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
import sentencepiece as spm

logger = logging.getLogger(__name__)

@@ -109,21 +111,16 @@ class PretrainingConfig(FairseqDataclass):
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
spm_model: str = field(
default="",
metadata={
"help": "sentencepice model to tokenize the data"
},
metadata={"help": "sentencepice model to tokenize the data"},
)
dict_file: str = field(
default="",
metadata={
"help": ""
},
metadata={"help": ""},
)


@register_task("pretraining", dataclass=PretrainingConfig)
class PLMTask(FairseqTask):

def __init__(self, cfg, dictionary, tokenizer):
super().__init__(cfg)
self.cfg = cfg
@@ -156,9 +153,9 @@ class PLMTask(FairseqTask):

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = {
'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
'data_dir': self.cfg.data,
'shuffle': True if split == 'train' else False,
"data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
"data_dir": self.cfg.data,
"shuffle": True if split == "train" else False,
}
self.datasets[split] = Namespace(**self.datasets[split])

@@ -185,18 +182,18 @@ class PLMTask(FairseqTask):
disable_iterator_cache=False,
):
return MLMLoader(
self.cfg,
dataset,
self.dictionary,
self.tokenizer,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
self.cfg,
dataset,
self.dictionary,
self.tokenizer,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
)

@property
@@ -4,7 +4,6 @@
# flake8: noqa
import models
import tasks

from fairseq_cli.train import cli_main

if __name__ == "__main__":
@@ -1,17 +1,21 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch
import warnings
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
import torch.distributed as dist
import math
import warnings

import torch
import torch.distributed as dist
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm


@torch.no_grad()
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
def clip_grad_norm_(
params, max_norm, moe_expert_count, aggregate_norm_fn=None
) -> torch.Tensor:
def grad_exists(p):
return p is not None and getattr(p, "grad", None) is not None

if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
@@ -59,7 +63,9 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
for split_grads in [expert_grads, sharded_grads]:
if len(split_grads) == 0:
continue
split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
split_norm = torch.norm(
torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
)
if dist.is_initialized():
split_norm.pow_(2)
dist.all_reduce(split_norm)
setup.py
@@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]

from io import open

from setuptools import find_packages, setup

setup(
@@ -10,19 +11,15 @@ setup(
author="TorchScale Team",
author_email="Shuming.Ma@microsoft.com",
description="Transformers at any scale",
long_description=open("README.md", "r", encoding='utf-8').read(),
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="Transformers at any scale",
license="MIT",
url="https://github.com/msranlp/torchscale",
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=['apex',
'torch>=1.8',
'fairscale==0.4.0',
'timm==0.4.12'],
python_requires='>=3.8.0',
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
python_requires=">=3.8.0",
classifiers=[
'Programming Language :: Python :: 3',
"Programming Language :: Python :: 3",
],
)
@@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]

import pytest
import torch

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
import torch

testcases = [
{},
@@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]
@@ -2,9 +2,10 @@
# Licensed under The MIT License [see LICENSE for details]

import pytest
import torch

from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder
import torch

testcases = [
{},
@@ -20,7 +21,7 @@ testcases = [
{"multiway": True},
{"share_encoder_input_output_embed": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]
@@ -2,10 +2,11 @@
# Licensed under The MIT License [see LICENSE for details]

import pytest
import torch

from torchscale.architecture.config import EncoderDecoderConfig
from torchscale.architecture.encoder_decoder import EncoderDecoder
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
import torch
from torchscale.component.embedding import PositionalEmbedding, TextEmbedding

testcases = [
{},
@@ -16,13 +17,18 @@ testcases = [
{"no_scale_embedding": False},
{"layernorm_embedding": True},
{"rel_pos_buckets": 32, "max_rel_pos": 256},
{"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
{
"deepnorm": True,
"subln": False,
"encoder_normalize_before": False,
"decoder_normalize_before": False,
},
{"bert_init": True},
{"multiway": True},
{"share_decoder_input_output_embed": True},
{"share_all_embeddings": True},
{"checkpoint_activations": True},
{"fsdp": True}
{"fsdp": True},
]

@@ -33,8 +39,12 @@ def test_decoder(args):
config,
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
encoder_embed_positions=PositionalEmbedding(
config.max_source_positions, config.encoder_embed_dim
),
decoder_embed_positions=PositionalEmbedding(
config.max_target_positions, config.decoder_embed_dim
),
)

src_tokens = torch.ones(2, 20).long()
@@ -1,6 +1,7 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]


class EncoderConfig(object):
def __init__(self, **kwargs):
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
@@ -19,9 +20,13 @@ class EncoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -29,7 +34,9 @@ class EncoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
self.share_encoder_input_output_embed = kwargs.pop(
"share_encoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@@ -78,9 +85,13 @@ class DecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -88,7 +99,9 @@ class DecoderConfig(object):
self.subln = kwargs.pop("subln", True)
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
# Text
@@ -138,9 +151,13 @@ class EncoderDecoderConfig(object):
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
@@ -149,7 +166,9 @@ class EncoderDecoderConfig(object):
self.bert_init = kwargs.pop("bert_init", False)
self.multiway = kwargs.pop("multiway", False)
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
@@ -2,22 +2,23 @@
# Licensed under The MIT License [see LICENSE for details]

import math

import numpy as np
import torch
import torch.nn as nn
import numpy as np
from fairscale.nn import checkpoint_wrapper, wrap
from apex.normalization import FusedLayerNorm as LayerNorm
from fairscale.nn import checkpoint_wrapper, wrap

from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.droppath import DropPath
from torchscale.architecture.utils import init_bert_params
from torchscale.component.relative_position_bias import RelativePositionBias
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate


class DecoderLayer(nn.Module):

def __init__(
self,
args,
@@ -31,7 +32,9 @@ class DecoderLayer(nn.Module):
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)

if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
@@ -206,7 +209,6 @@ class DecoderLayer(nn.Module):

class Decoder(nn.Module):

def __init__(
self,
args,
@@ -228,7 +230,11 @@ class Decoder(nn.Module):
self.embed_tokens = embed_tokens
self.embed_positions = embed_positions

if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
if (
output_projection is None
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
@@ -286,7 +292,12 @@ class Decoder(nn.Module):
else:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)

if args.subln:
@ -295,9 +306,14 @@ class Decoder(nn.Module):
|
|||
else:
|
||||
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
|
||||
for name, p in self.named_parameters():
|
||||
if 'encoder_attn' in name:
|
||||
if "encoder_attn" in name:
|
||||
continue
|
||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
||||
if (
|
||||
"fc1" in name
|
||||
or "fc2" in name
|
||||
or "out_proj" in name
|
||||
or "v_proj" in name
|
||||
):
|
||||
p.data.mul_(init_scale)
|
||||
|
||||
def build_output_projection(
|
||||
|
@ -316,16 +332,12 @@ class Decoder(nn.Module):
|
|||
args.decoder_embed_dim, args.vocab_size, bias=False
|
||||
)
|
||||
torch.nn.init.normal_(
|
||||
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
||||
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
|
||||
)
|
||||
return output_projection
|
||||
|
||||
def build_decoder_layer(
|
||||
self,
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=False,
|
||||
is_encoder_decoder=False
|
||||
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
|
||||
):
|
||||
layer = DecoderLayer(
|
||||
args,
|
||||
|
@ -347,7 +359,9 @@ class Decoder(nn.Module):
|
|||
):
|
||||
positions = None
|
||||
if self.embed_positions is not None:
|
||||
positions = self.embed_positions(tokens, incremental_state=incremental_state)
|
||||
positions = self.embed_positions(
|
||||
tokens, incremental_state=incremental_state
|
||||
)
|
||||
|
||||
if incremental_state is not None:
|
||||
tokens = tokens[:, -1:]
|
||||
|
@ -381,7 +395,9 @@ class Decoder(nn.Module):
|
|||
**kwargs
|
||||
):
|
||||
# embed tokens and positions
|
||||
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
|
||||
x, _ = self.forward_embedding(
|
||||
prev_output_tokens, token_embeddings, incremental_state
|
||||
)
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
# relative postion
|
||||
|
@ -389,9 +405,7 @@ class Decoder(nn.Module):
|
|||
slen = prev_output_tokens.size(1)
|
||||
if self.self_attn_relative_position is not None:
|
||||
self_attn_rel_pos_bias = self.self_attn_relative_position(
|
||||
batch_size=x.size(1),
|
||||
qlen=slen,
|
||||
klen=slen
|
||||
batch_size=x.size(1), qlen=slen, klen=slen
|
||||
)
|
||||
if incremental_state is not None:
|
||||
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
|
||||
|
@ -416,7 +430,11 @@ class Decoder(nn.Module):
|
|||
for idx, layer in enumerate(self.layers):
|
||||
if incremental_state is None:
|
||||
self_attn_mask = torch.triu(
|
||||
torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
|
||||
torch.zeros([x.size(0), x.size(0)])
|
||||
.float()
|
||||
.fill_(float("-inf"))
|
||||
.type_as(x),
|
||||
1,
|
||||
)
|
||||
else:
|
||||
self_attn_mask = None
|
||||
|
@ -426,7 +444,9 @@ class Decoder(nn.Module):
|
|||
x, layer_attn, _, l_aux_i = layer(
|
||||
x,
|
||||
encoder_out["encoder_out"] if encoder_out is not None else None,
|
||||
encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
|
||||
encoder_out["encoder_padding_mask"]
|
||||
if encoder_out is not None
|
||||
else None,
|
||||
incremental_state[idx] if incremental_state is not None else None,
|
||||
self_attn_mask=self_attn_mask,
|
||||
self_attn_padding_mask=self_attn_padding_mask,
|
||||
|
@ -444,7 +464,11 @@ class Decoder(nn.Module):
|
|||
if not features_only:
|
||||
x = self.output_layer(x)
|
||||
|
||||
return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}
|
||||
return x, {
|
||||
"inner_states": inner_states,
|
||||
"l_aux": l_aux,
|
||||
"attn": [layer_attn.mean(dim=0)],
|
||||
}
|
||||
|
||||
def output_layer(self, features):
|
||||
return self.output_projection(features)

@@ -2,30 +2,25 @@
|
|||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from fairscale.nn import checkpoint_wrapper, wrap
|
||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||
from fairscale.nn import checkpoint_wrapper, wrap
|
||||
|
||||
from torchscale.architecture.utils import init_bert_params
|
||||
from torchscale.component.droppath import DropPath
|
||||
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
||||
from torchscale.component.multihead_attention import MultiheadAttention
|
||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
|
||||
from torchscale.component.droppath import DropPath
|
||||
from torchscale.architecture.utils import init_bert_params
|
||||
from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
|
||||
from torchscale.component.relative_position_bias import RelativePositionBias
|
||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=False,
|
||||
is_encoder_decoder=False
|
||||
):
|
||||
def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.embed_dim = args.encoder_embed_dim
|
||||
|
@ -34,7 +29,9 @@ class EncoderLayer(nn.Module):
|
|||
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
||||
|
||||
if args.drop_path_rate > 0:
|
||||
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
|
||||
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
|
||||
depth
|
||||
]
|
||||
self.drop_path = DropPath(drop_path_prob)
|
||||
else:
|
||||
self.drop_path = None
|
||||
|
@ -49,7 +46,7 @@ class EncoderLayer(nn.Module):
|
|||
self.build_ffn(
|
||||
self.embed_dim,
|
||||
self.args,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert not self.args.multiway
|
||||
|
@ -77,7 +74,12 @@ class EncoderLayer(nn.Module):
|
|||
|
||||
if args.deepnorm:
|
||||
if is_encoder_decoder:
|
||||
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
|
||||
self.alpha = (
|
||||
math.pow(
|
||||
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
|
||||
)
|
||||
* 0.81
|
||||
)
|
||||
else:
|
||||
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
|
||||
else:
|
||||
|
@ -107,13 +109,7 @@ class EncoderLayer(nn.Module):
|
|||
def residual_connection(self, x, residual):
|
||||
return residual * self.alpha + x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
encoder_padding_mask,
|
||||
attn_mask=None,
|
||||
rel_pos=None
|
||||
):
|
||||
def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
|
||||
|
||||
|
@ -158,7 +154,6 @@ class EncoderLayer(nn.Module):
|
|||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
|
@ -179,13 +174,20 @@ class Encoder(nn.Module):
|
|||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = embed_positions
|
||||
|
||||
if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
|
||||
if (
|
||||
output_projection is None
|
||||
and not is_encoder_decoder
|
||||
and not args.no_output_layer
|
||||
and args.vocab_size > 0
|
||||
):
|
||||
self.output_projection = self.build_output_projection(args)
|
||||
else:
|
||||
self.output_projection = output_projection
|
||||
|
||||
if args.layernorm_embedding:
|
||||
self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
|
||||
self.layernorm_embedding = MultiwayWrapper(
|
||||
args, LayerNorm(embed_dim), dim=1
|
||||
)
|
||||
else:
|
||||
self.layernorm_embedding = None
|
||||
|
||||
|
@ -199,7 +201,7 @@ class Encoder(nn.Module):
|
|||
args,
|
||||
depth=i,
|
||||
is_moe_layer=is_moe_layer,
|
||||
is_encoder_decoder=is_encoder_decoder
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
)
|
||||
)
|
||||
self.num_layers = len(self.layers)
|
||||
|
@ -223,20 +225,39 @@ class Encoder(nn.Module):
|
|||
|
||||
if args.deepnorm:
|
||||
if is_encoder_decoder:
|
||||
init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
|
||||
init_scale = (
|
||||
math.pow(
|
||||
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
|
||||
)
|
||||
/ 1.15
|
||||
)
|
||||
else:
|
||||
init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
|
||||
for name, p in self.named_parameters():
|
||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
||||
if (
|
||||
"fc1" in name
|
||||
or "fc2" in name
|
||||
or "out_proj" in name
|
||||
or "v_proj" in name
|
||||
):
|
||||
p.data.div_(init_scale)
|
||||
|
||||
if args.subln:
|
||||
if is_encoder_decoder:
|
||||
init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
|
||||
init_scale = math.sqrt(
|
||||
math.log(3 * args.decoder_layers)
|
||||
* math.log(2 * args.encoder_layers)
|
||||
/ 3
|
||||
)
|
||||
else:
|
||||
init_scale = math.sqrt(math.log(args.encoder_layers * 2))
|
||||
for name, p in self.named_parameters():
|
||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
||||
if (
|
||||
"fc1" in name
|
||||
or "fc2" in name
|
||||
or "out_proj" in name
|
||||
or "v_proj" in name
|
||||
):
|
||||
p.data.mul_(init_scale)
|
||||
|
||||
def build_output_projection(
|
||||
|
@ -244,7 +265,7 @@ class Encoder(nn.Module):
|
|||
args,
|
||||
):
|
||||
if args.share_encoder_input_output_embed:
|
||||
assert args.encoder_embedding_type == 'language'
|
||||
assert args.encoder_embedding_type == "language"
|
||||
output_projection = torch.nn.Linear(
|
||||
self.embed_tokens.weight.shape[1],
|
||||
self.embed_tokens.weight.shape[0],
|
||||
|
@ -256,22 +277,18 @@ class Encoder(nn.Module):
|
|||
args.encoder_embed_dim, args.vocab_size, bias=False
|
||||
)
|
||||
torch.nn.init.normal_(
|
||||
output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
|
||||
output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
|
||||
)
|
||||
return output_projection
|
||||
|
||||
def build_encoder_layer(
|
||||
self,
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=False,
|
||||
is_encoder_decoder=False
|
||||
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
|
||||
):
|
||||
layer = EncoderLayer(
|
||||
args,
|
||||
depth,
|
||||
is_moe_layer=is_moe_layer,
|
||||
is_encoder_decoder=is_encoder_decoder
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
)
|
||||
if args.checkpoint_activations:
|
||||
layer = checkpoint_wrapper(layer)
|
||||
|
@ -312,13 +329,12 @@ class Encoder(nn.Module):
|
|||
if encoder_padding_mask is None:
|
||||
if src_tokens is not None:
|
||||
encoder_padding_mask = torch.zeros_like(
|
||||
src_tokens,
|
||||
device=src_tokens.device
|
||||
src_tokens, device=src_tokens.device
|
||||
).bool()
|
||||
else:
|
||||
encoder_padding_mask = torch.zeros(
|
||||
[token_embeddings.size(0), token_embeddings.size(1)],
|
||||
device=token_embeddings.device
|
||||
device=token_embeddings.device,
|
||||
).bool()
|
||||
|
||||
if multiway_split_position is not None:
|
||||
|
@ -338,16 +354,13 @@ class Encoder(nn.Module):
|
|||
rel_pos_bias = None
|
||||
if self.relative_position is not None:
|
||||
rel_pos_bias = self.relative_position(
|
||||
batch_size=x.size(1),
|
||||
qlen=x.size(0),
|
||||
klen=x.size(0)
|
||||
batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
|
||||
)
|
||||
|
||||
l_aux = []
|
||||
for layer in self.layers:
|
||||
x, l_aux_i = layer(
|
||||
x, encoder_padding_mask=encoder_padding_mask,
|
||||
rel_pos=rel_pos_bias
|
||||
x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
|
||||
)
|
||||
if return_all_hiddens:
|
||||
assert encoder_states is not None
|
||||
|
|
|
@@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn
from torchscale.architecture.encoder import Encoder

from torchscale.architecture.decoder import Decoder
from torchscale.architecture.encoder import Encoder


class EncoderDecoder(nn.Module):
def __init__(
self,
args,

@@ -51,10 +51,7 @@ class EncoderDecoder(nn.Module):
features_only=False,
**kwargs
):
encoder_out = self.encoder(
src_tokens,
return_all_hiddens=return_all_hiddens
)
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,

@@ -2,12 +2,12 @@
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn

from torchscale.component.multihead_attention import MultiheadAttention
from torchscale.component.multiway_network import MultiwayNetwork


def init_bert_params(module):

def normal_(data):
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

@@ -1,13 +1,13 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

from timm.models.layers import drop_path
import torch.nn as nn
from timm.models.layers import drop_path


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

@@ -16,4 +16,4 @@ class DropPath(nn.Module):
return drop_path(x, self.drop_prob, self.training)

def extra_repr(self):
return 'p={}'.format(self.drop_prob)
return "p={}".format(self.drop_prob)

@@ -7,22 +7,12 @@ import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):
def __init__(
self,
text_embed,
vision_embed
):
def __init__(self, text_embed, vision_embed):
super().__init__()
self.text_embed = text_embed
self.vision_embed = vision_embed

def forward(
self,
textual_tokens,
visual_tokens,
**kwargs
):
def forward(self, textual_tokens, visual_tokens, **kwargs):
if textual_tokens is None:
return self.vision_embed(visual_tokens)

@@ -36,8 +26,8 @@ class VisionLanguageEmbedding(nn.Module):


class VisionEmbedding(nn.Module):
""" Image to Patch Embedding
"""
"""Image to Patch Embedding"""

def __init__(
self,
img_size=224,

@@ -45,7 +35,7 @@ class VisionEmbedding(nn.Module):
in_chans=3,
embed_dim=768,
contain_mask_token=False,
prepend_cls_token=False
prepend_cls_token=False,
):
super().__init__()
img_size = (img_size, img_size)

@@ -56,7 +46,9 @@ class VisionEmbedding(nn.Module):
self.patch_size = patch_size
self.num_patches = num_patches

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)

if contain_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

@@ -68,15 +60,11 @@ class VisionEmbedding(nn.Module):
else:
self.cls_token = None

def forward(
self,
x,
masked_position=None,
**kwargs
):
def forward(self, x, masked_position=None, **kwargs):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)

batch_size, seq_len, _ = x.size()

@@ -88,21 +76,21 @@ class VisionEmbedding(nn.Module):
x = x * (1 - w) + mask_token * w

if self.cls_token is not None:
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(
batch_size, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)

return x


class TextEmbedding(nn.Embedding):
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
self._fill_padding_idx_with_zero()


class PositionalEmbedding(nn.Embedding):
def forward(
self,
x,

@@ -111,7 +99,9 @@ class PositionalEmbedding(nn.Embedding):
):
if positions is None:
# being consistent with Fairseq, which starts from 2.
positions = torch.arange(2, x.size(1)+2, device=x.device).long().unsqueeze(0)
positions = (
torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
)
return F.embedding(
positions,
self.weight,

@@ -35,13 +35,19 @@ class set_torch_seed(object):


def make_experts(args, embed_dim, expert_ffn_dim):
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
world_size = (
1
if not torch.distributed.is_initialized()
else torch.distributed.get_world_size()
)
expert_list = []
ddp_rank = args.ddp_rank
start_seed = torch.randint(1000000, (1,)).item()
# at least as many experts than gpus
if args.moe_expert_count >= world_size:
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
assert (
args.moe_expert_count % world_size == 0
), f"{args.moe_expert_count}, {world_size}"
local_moe_expert_count = args.moe_expert_count // world_size
for i in range(local_moe_expert_count):
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):

@@ -52,11 +58,13 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
args.subln,
)
)
else:
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
assert (
world_size % args.moe_expert_count == 0
), f"{world_size}, {args.moe_expert_count}"

with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
expert_list.append(

@@ -66,7 +74,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
args.activation_fn,
args.dropout,
args.activation_dropout,
args.subln
args.subln,
)
)
experts = nn.ModuleList(expert_list)

@@ -83,7 +91,6 @@ def get_activation_fn(activation):


class FeedForwardNetwork(nn.Module):
def __init__(
self,
embed_dim,

@@ -91,12 +98,14 @@ class FeedForwardNetwork(nn.Module):
activation_fn,
dropout,
activation_dropout,
subln=False
subln=False,
):
super().__init__()
self.embed_dim = embed_dim
self.activation_fn = get_activation_fn(activation=str(activation_fn))
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
self.activation_dropout_module = torch.nn.Dropout(
activation_dropout, inplace=True
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)

@@ -2,15 +2,16 @@
# Licensed under The MIT License [see LICENSE for details]

import math

import torch
from torch import nn
import torch.nn.functional as F
from apex.normalization import FusedLayerNorm as LayerNorm
from torch import nn

from .multiway_network import MultiwayWrapper


class MultiheadAttention(nn.Module):
def __init__(
self,
args,

@@ -25,7 +26,7 @@ class MultiheadAttention(nn.Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5

self.self_attention = self_attention
self.encoder_decoder_attention = encoder_decoder_attention

@@ -34,8 +35,14 @@ class MultiheadAttention(nn.Module):
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
self.out_proj = MultiwayWrapper(
args, nn.Linear(embed_dim, embed_dim, bias=True)
)
self.inner_attn_ln = (
MultiwayWrapper(args, LayerNorm(self.embed_dim))
if subln and self.self_attention
else None
)
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)

def reset_parameters(self):

@@ -76,12 +83,20 @@ class MultiheadAttention(nn.Module):

if incremental_state is not None:
if "prev_key" in incremental_state:
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
prev_key = incremental_state["prev_key"].view(
bsz * self.num_heads, -1, self.head_dim
)
prev_value = incremental_state["prev_value"].view(
bsz * self.num_heads, -1, self.head_dim
)
k = torch.cat([prev_key, k], dim=1)
v = torch.cat([prev_value, v], dim=1)
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
incremental_state["prev_key"] = k.view(
bsz, self.num_heads, -1, self.head_dim
)
incremental_state["prev_value"] = v.view(
bsz, self.num_heads, -1, self.head_dim
)
src_len = k.size(1)

attn_weights = torch.bmm(q, k.transpose(1, 2))

@@ -103,7 +118,9 @@ class MultiheadAttention(nn.Module):
rel_pos = rel_pos.view(attn_weights.size())
attn_weights = attn_weights + rel_pos

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
attn_weights
)
attn_probs = self.dropout_module(attn_weights)

attn = torch.bmm(attn_probs, v)

@@ -2,6 +2,7 @@
# Licensed under The MIT License [see LICENSE for details]

import copy

import torch
import torch.nn as nn

@@ -13,16 +14,14 @@ def MultiwayWrapper(args, module, dim=0):


def set_split_position(position):
def apply_fn(module):
if hasattr(module, 'split_position'):
if hasattr(module, "split_position"):
module.split_position = position

return apply_fn


class MultiwayNetwork(nn.Module):
def __init__(self, module, dim=0):
super().__init__()
self.dim = dim

@@ -36,7 +35,11 @@ class MultiwayNetwork(nn.Module):
return self.A(x, **kwargs)
if self.split_position == 0:
return self.B(x, **kwargs)
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim)
x1, x2 = torch.split(
x,
[self.split_position, x.size(self.dim) - self.split_position],
dim=self.dim,
)
# x1, x2 = x[:self.split_position], x[self.split_position:]
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
return torch.cat([y1, y2], dim=self.dim)
|
||||
|
|
|
@ -2,17 +2,14 @@
|
|||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
bidirectional=True,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
n_heads=12
|
||||
self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
|
||||
):
|
||||
super().__init__()
|
||||
self.bidirectional = bidirectional
|
||||
|
@ -23,10 +20,7 @@ class RelativePositionBias(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _relative_position_bucket(
|
||||
relative_position,
|
||||
bidirectional=True,
|
||||
num_buckets=32,
|
||||
max_distance=128
|
||||
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||
):
|
||||
ret = 0
|
||||
n = -relative_position
|
||||
|
@ -41,24 +35,28 @@ class RelativePositionBias(nn.Module):
|
|||
is_small = n < max_exact
|
||||
|
||||
val_if_large = max_exact + (
|
||||
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||
torch.log(n.float() / max_exact)
|
||||
/ math.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(torch.long)
|
||||
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
||||
val_if_large = torch.min(
|
||||
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
|
||||
)
|
||||
|
||||
ret += torch.where(is_small, n, val_if_large)
|
||||
return ret
|
||||
|
||||
def compute_bias(
|
||||
self,
|
||||
qlen,
|
||||
klen,
|
||||
step=None
|
||||
):
|
||||
def compute_bias(self, qlen, klen, step=None):
|
||||
step = 0 if step is None else step
|
||||
context_position = torch.arange(step, step + qlen, dtype=torch.long,
|
||||
device=self.relative_attention_bias.weight.device)[:, None]
|
||||
memory_position = torch.arange(klen, dtype=torch.long,
|
||||
device=self.relative_attention_bias.weight.device)[None, :]
|
||||
context_position = torch.arange(
|
||||
step,
|
||||
step + qlen,
|
||||
dtype=torch.long,
|
||||
device=self.relative_attention_bias.weight.device,
|
||||
)[:, None]
|
||||
memory_position = torch.arange(
|
||||
klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
|
||||
)[None, :]
|
||||
relative_position = memory_position - context_position # shape (qlen, klen)
|
||||
|
||||
rp_bucket = self._relative_position_bucket(
|
||||
|
@ -67,16 +65,18 @@ class RelativePositionBias(nn.Module):
|
|||
num_buckets=self.num_buckets,
|
||||
)
|
||||
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
||||
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
|
||||
values = self.relative_attention_bias(
|
||||
rp_bucket
|
||||
) # shape (qlen, klen, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(
|
||||
0
|
||||
) # shape (1, num_heads, qlen, klen)
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size,
|
||||
qlen,
|
||||
klen,
|
||||
step=None
|
||||
):
|
||||
def forward(self, batch_size, qlen, klen, step=None):
|
||||
# shape (batch * num_heads, qlen, klen)
|
||||
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
||||
return (
|
||||
self.compute_bias(qlen, klen, step)
|
||||
.repeat(batch_size, 1, 1, 1)
|
||||
.view(-1, qlen, klen)
|
||||
)
|
||||
|
|
|
@ -18,9 +18,9 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.nn import Module, ModuleList
|
||||
|
||||
|
||||
try:
|
||||
from fairseq.modules.moe import MOELayer
|
||||
|
||||
has_fairseq = True
|
||||
Base = MOELayer
|
||||
except ModuleNotFoundError:
|
||||
|
@ -81,8 +81,10 @@ def get_moe_group(moe_expert_count):
|
|||
else:
|
||||
assert world_size % moe_expert_count == 0
|
||||
ranks_per_group = world_size // moe_expert_count
|
||||
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
|
||||
for i in range(moe_expert_count)]
|
||||
moe_groups = [
|
||||
[i + j * moe_expert_count for j in range(ranks_per_group)]
|
||||
for i in range(moe_expert_count)
|
||||
]
|
||||
|
||||
get_moe_group._moe_group_idx = moe_groups
|
||||
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
||||
|
@ -105,11 +107,15 @@ def get_all2all_group(moe_expert_count):
|
|||
else:
|
||||
assert world_size % moe_expert_count == 0
|
||||
ranks_per_group = world_size // moe_expert_count
|
||||
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
|
||||
for i in range(ranks_per_group)]
|
||||
all2all_groups = [
|
||||
[i * moe_expert_count + j for j in range(moe_expert_count)]
|
||||
for i in range(ranks_per_group)
|
||||
]
|
||||
|
||||
get_all2all_group._all2all_group_idx = all2all_groups
|
||||
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]
|
||||
get_all2all_group._all2all_groups = [
|
||||
dist.new_group(g) for g in all2all_groups
|
||||
]
|
||||
|
||||
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
|
||||
return get_all2all_group._all2all_groups[my_group_idx]
|
||||
|
@ -133,12 +139,7 @@ class MOELayer(Base):
|
|||
expert network
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gate,
|
||||
experts,
|
||||
args
|
||||
):
|
||||
def __init__(self, gate, experts, args):
|
||||
if has_fairseq:
|
||||
super(Base, self).__init__()
|
||||
else:
|
||||
|
@ -163,9 +164,13 @@ class MOELayer(Base):
|
|||
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
|
||||
assert len(input) == 1, "only single input Tensor supported"
|
||||
input = input[0]
|
||||
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
|
||||
assert (
|
||||
len(input.shape) == 3
|
||||
), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
|
||||
if input_padding_mask is not None:
|
||||
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
|
||||
assert (
|
||||
len(input_padding_mask.shape) == 2
|
||||
), "input Tensor must have dimensions: (s)equence, (t)oken"
|
||||
assert input_padding_mask.shape[0] == input.shape[0]
|
||||
assert input_padding_mask.shape[1] == input.shape[1]
|
||||
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
|
||||
|
@ -174,81 +179,120 @@ class MOELayer(Base):
|
|||
d_model = input.shape[2]
|
||||
# Pad to expected batch size
|
||||
input_shape = list(input.shape)
|
||||
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
|
||||
expected_bsz = (
|
||||
getattr(self.args, "batch_size", 0)
|
||||
if self.training
|
||||
else getattr(self.args, "batch_size_valid", 0)
|
||||
)
|
||||
# This indicates that --batch-size or --max-sentences is not specified
|
||||
if expected_bsz is None:
|
||||
expected_bsz = 0
|
||||
# Note: Padding is not necessary at generation time at present
|
||||
# because all DDP workers process the same batch. Also, batch size at generation time
|
||||
# can be different from that present in the checkpoint state
|
||||
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
|
||||
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
|
||||
if (
|
||||
not self.in_generation
|
||||
and expected_bsz != 0
|
||||
and input_shape[0] != expected_bsz
|
||||
):
|
||||
logger.warning(
|
||||
f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
|
||||
)
|
||||
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
|
||||
padded_input = torch.zeros(
|
||||
(expected_bsz, input_shape[1], input_shape[2]),
|
||||
dtype=input.dtype, layout=input.layout, device=input.device)
|
||||
padded_input[:input_shape[0], :, :] = input
|
||||
dtype=input.dtype,
|
||||
layout=input.layout,
|
||||
device=input.device,
|
||||
)
|
||||
padded_input[: input_shape[0], :, :] = input
|
||||
input = padded_input
|
||||
|
||||
padded_input_padding_mask = torch.ones(
|
||||
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
|
||||
(
|
||||
expected_bsz,
|
||||
input_shape[1],
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=input.device,
|
||||
)
|
||||
if input_padding_mask is not None:
|
||||
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
|
||||
padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
|
||||
else:
|
||||
padded_input_padding_mask[:input_shape[0], :] = False
|
||||
padded_input_padding_mask[: input_shape[0], :] = False
|
||||
input_padding_mask = padded_input_padding_mask
|
||||
|
||||
# Reshape into S tokens by dropping sequence dimension.
|
||||
reshaped_input = input.reshape(-1, d_model)
|
||||
reshaped_input_shape = reshaped_input.shape
|
||||
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None
|
||||
reshaped_input_padding_mask = (
|
||||
input_padding_mask.reshape(-1) if input_padding_mask is not None else None
|
||||
)
|
||||
|
||||
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||
# Pro of --max-tokens: more flexible for MT variable sequence lengths
|
||||
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
|
||||
if expected_bsz == 0:
|
||||
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
|
||||
expected_dim = reshaped_input_shape[0] * torch.ones(
|
||||
(1,), dtype=torch.long, device=input.device
|
||||
)
|
||||
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
|
||||
expected_dim = int(expected_dim.item())
|
||||
padded_input = torch.zeros(
|
||||
(expected_dim, reshaped_input_shape[1]),
|
||||
dtype=input.dtype, layout=input.layout, device=input.device)
|
||||
padded_input[:reshaped_input_shape[0], :] = reshaped_input
|
||||
dtype=input.dtype,
|
||||
layout=input.layout,
|
||||
device=input.device,
|
||||
)
|
||||
padded_input[: reshaped_input_shape[0], :] = reshaped_input
|
||||
reshaped_input = padded_input
|
||||
|
||||
padded_input_padding_mask = torch.ones(
|
||||
(expected_dim,), dtype=torch.bool, device=padded_input.device
|
||||
)
|
||||
if reshaped_input_padding_mask is not None:
|
||||
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
|
||||
padded_input_padding_mask[
|
||||
: reshaped_input_shape[0]
|
||||
] = reshaped_input_padding_mask
|
||||
else:
|
||||
padded_input_padding_mask[:reshaped_input_shape[0]] = False
|
||||
padded_input_padding_mask[: reshaped_input_shape[0]] = False
|
||||
reshaped_input_padding_mask = padded_input_padding_mask
|
||||
|
||||
if has_tutel:
|
||||
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
|
||||
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
|
||||
reshaped_input, reshaped_input_padding_mask
|
||||
)
|
||||
S, M = reshaped_input.size(0), reshaped_input.size(1)
|
||||
|
||||
if not hasattr(self, '_tutel_dispatcher'):
|
||||
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
|
||||
if not hasattr(self, "_tutel_dispatcher"):
|
||||
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
|
||||
E, C, M, dispatch_dtype=reshaped_input.dtype
|
||||
)
|
||||
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
|
||||
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
|
||||
else:
|
||||
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)
|
||||
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
|
||||
reshaped_input, reshaped_input_padding_mask
|
||||
)
|
||||
|
||||
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S
|
||||
dispatch_mask = dispatch_mask.to(input.dtype).permute(
|
||||
1, 2, 0
|
||||
) # S,E,C -> E,C,S
|
||||
E, C, S = dispatch_mask.size()
|
||||
M = reshaped_input.size(1)
|
||||
assert reshaped_input.size() == (S, M)
|
||||
# einsum("sec,sm->ecm")
|
||||
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M
|
||||
dispatched_input = torch.mm(
|
||||
dispatch_mask.view(E * C, S), reshaped_input
|
||||
) # -> (E*C),M
|
||||
|
||||
if self.all2all_size > 1:
|
||||
dispatched_input = self.all_to_all_wrapper(dispatched_input)
|
||||
|
||||
# Re-shape after all-to-all: ecm -> gecm
|
||||
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
|
||||
dispatched_input = dispatched_input.reshape(
|
||||
self.all2all_size, self.num_local_experts, -1, d_model
|
||||
)
|
||||
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
|
||||
expert_outputs = []
|
||||
for chunk, expert in zip(chunks, self.experts):
|
||||
|
@ -259,18 +303,24 @@ class MOELayer(Base):
|
|||
expert_output = self.all_to_all_wrapper(expert_output)
|
||||
|
||||
# Re-shape back: gecm -> ecm
|
||||
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)
|
||||
expert_output = expert_output.reshape(
|
||||
self.all2all_size * self.num_local_experts, -1, d_model
|
||||
)
|
||||
|
||||
if has_tutel:
|
||||
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
|
||||
combined_output = self._tutel_dispatcher.decode(
|
||||
expert_output.view(E * C, M)
|
||||
)
|
||||
else:
|
||||
# einsum("sec,ecm->sm")
|
||||
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))
|
||||
combined_output = combine_weights.view(S, E * C).mm(
|
||||
expert_output.view(E * C, M)
|
||||
)
|
||||
|
||||
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||
combined_output = combined_output[:reshaped_input_shape[0], :]
|
||||
combined_output = combined_output[: reshaped_input_shape[0], :]
|
||||
combined_output = combined_output.reshape(input.shape)
|
||||
combined_output = combined_output[:input_shape[0], :, :]
|
||||
combined_output = combined_output[: input_shape[0], :, :]
|
||||
|
||||
self.record_all_to_all_stats()
|
||||
|
||||
|
@ -280,7 +330,7 @@ class MOELayer(Base):
|
|||
self.in_generation = True
|
||||
|
||||
def all_to_all_wrapper(self, input: Tensor):
|
||||
dummy_a2a = getattr(self.args, 'dummy_a2a', False)
|
||||
dummy_a2a = getattr(self.args, "dummy_a2a", False)
|
||||
if dummy_a2a:
|
||||
input = input.contiguous()
|
||||
output = input.detach().clone()
|
||||
|
@ -294,13 +344,13 @@ class MOELayer(Base):
|
|||
output = _AllToAll.apply(self.all2all_group, input)
|
||||
cuda_end.record()
|
||||
cpu_end = time.time() * 1000
|
||||
self.a2a_cpu_time_ms += (cpu_end - cpu_start)
|
||||
self.a2a_cpu_time_ms += cpu_end - cpu_start
|
||||
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
|
||||
return output
|
||||
|
||||
def record_all_to_all_stats(self):
|
||||
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
|
||||
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
|
||||
record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
|
||||
if record_a2a_perf_stats:
|
||||
torch.cuda.synchronize()
|
||||
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
|
||||
|
|
|
@ -13,14 +13,14 @@
|
|||
# NOTE: This is a mirror of the code in
|
||||
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||
|
||||
from typing import Callable, Dict, Tuple, Optional
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
from .moe_layer import has_tutel, fused_cumsum_sub_one
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from .moe_layer import fused_cumsum_sub_one, has_tutel
|
||||
|
||||
# use a fixed temperature to compute balance loss
|
||||
TEMPERATURE_FOR_L_UAX = 0.07
|
||||
|
@ -65,13 +65,22 @@ def top1gating(
|
|||
indices1_s = torch.argmax(gates, dim=1)
|
||||
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
|
||||
if input_mask is not None and input_mask.any():
|
||||
nonpadding = ~ input_mask
|
||||
nonpadding = ~input_mask
|
||||
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||
|
||||
# for logging (percent of tokens routed to each expert)
|
||||
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||
expert1_hist = (
|
||||
100
|
||||
* torch.histc(
|
||||
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||
)
|
||||
/ num_tokens
|
||||
)
|
||||
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||
expert1_hist = (
|
||||
torch.sort(expert1_hist, dim=0, descending=True).values
|
||||
+ torch.finfo(torch.float32).tiny
|
||||
)
|
||||
|
||||
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||
|
@ -91,7 +100,21 @@ def top1gating(
|
|||
|
||||
if has_tutel:
|
||||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]
|
||||
return (
|
||||
l_aux,
|
||||
metadata,
|
||||
capacity,
|
||||
num_experts,
|
||||
[
|
||||
indices1_s,
|
||||
],
|
||||
[
|
||||
locations1_s,
|
||||
],
|
||||
[
|
||||
gates1_s,
|
||||
],
|
||||
)
|
||||
|
||||
# Remove locations outside capacity from mask
|
||||
mask1 = mask1 * torch.lt(locations1, capacity)
|
||||
|
@ -104,7 +127,8 @@ def top1gating(
|
|||
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
|
||||
combine1_sec = torch.bmm(
|
||||
# einsum("se,sc->sec")
|
||||
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
||||
gates1.unsqueeze(-1),
|
||||
locations1_sc.to(gates1.dtype).unsqueeze(1),
|
||||
)
|
||||
dispatch_mask = combine1_sec.bool()
|
||||
if use_fp32:
|
||||
|
@ -218,10 +242,10 @@ def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) ->
|
|||
if unsqueeze_indices:
|
||||
indices = indices.unsqueeze(-1)
|
||||
assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
|
||||
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype)
|
||||
output.scatter_(
|
||||
len(output.shape) - 1, indices, 1
|
||||
output = torch.zeros(
|
||||
indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
|
||||
)
|
||||
output.scatter_(len(output.shape) - 1, indices, 1)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -235,7 +259,7 @@ def top2gating(
|
|||
logits: torch.Tensor,
|
||||
input_mask: Optional[torch.Tensor] = None,
|
||||
use_fp32=False,
|
||||
second_expert_policy='sampling',
|
||||
second_expert_policy="sampling",
|
||||
normalize_gate_prob_before_dropping=False,
|
||||
eval_mode=False,
|
||||
moe_eval_capacity_token_fraction=0.25,
|
||||
|
@ -260,7 +284,7 @@ def top2gating(
|
|||
# Create a mask for 1st's expert per token
|
||||
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
|
||||
mask1 = one_hot(indices1_s, num_experts)
|
||||
if second_expert_policy == 'sampling':
|
||||
if second_expert_policy == "sampling":
|
||||
# Create a mask for 2nd's expert per token using Gumbel-max trick
|
||||
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
|
||||
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
|
||||
|
@ -281,13 +305,13 @@ def top2gating(
|
|||
gates1_s = gates1_s / denom_s
|
||||
gates2_s = gates2_s / denom_s
|
||||
|
||||
if second_expert_policy == 'random':
|
||||
if second_expert_policy == "random":
|
||||
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
|
||||
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
|
||||
|
||||
# Compute locations in capacity buffer
|
||||
if input_mask is not None and input_mask.any():
|
||||
nonpadding = ~ input_mask
|
||||
nonpadding = ~input_mask
|
||||
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||
|
||||
|
@ -296,15 +320,22 @@ def top2gating(
|
|||
importance_scores = -1 * gates.max(dim=1)[0]
|
||||
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
|
||||
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
|
||||
importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
|
||||
importance_sorted_locations1 = sorted_cumsum1[
|
||||
importance_scores.argsort(dim=0).argsort(dim=0)
|
||||
]
|
||||
|
||||
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
|
||||
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
|
||||
importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
|
||||
importance_sorted_locations2 = sorted_cumsum2[
|
||||
importance_scores.argsort(dim=0).argsort(dim=0)
|
||||
]
|
||||
|
||||
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||
|
||||
locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
|
||||
locations1, locations2 = (
|
||||
importance_sorted_locations1,
|
||||
importance_sorted_locations2,
|
||||
)
|
||||
else:
|
||||
locations1 = fused_cumsum_sub_one(mask1)
|
||||
locations2 = fused_cumsum_sub_one(mask2)
|
||||
|
@ -318,8 +349,12 @@ def top2gating(
|
|||
l_aux = l_aux * num_experts * num_experts
|
||||
|
||||
# for logging purposes
|
||||
metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
|
||||
metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
|
||||
metadata["overflow_expert1"] = (
|
||||
100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
|
||||
)
|
||||
metadata["overflow_expert2"] = (
|
||||
100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
|
||||
)
|
||||
|
||||
# Remove locations outside capacity from mask
|
||||
mask1_, mask2_ = mask1, mask2
|
||||
|
@ -327,13 +362,31 @@ def top2gating(
|
|||
mask2 = mask2 * torch.lt(locations2, capacity)
|
||||
|
||||
# for logging (percent of tokens routed to each expert)
|
||||
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||
expert1_hist = (
|
||||
100
|
||||
* torch.histc(
|
||||
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||
)
|
||||
/ num_tokens
|
||||
)
|
||||
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||
expert1_hist = (
|
||||
torch.sort(expert1_hist, dim=0, descending=True).values
|
||||
+ torch.finfo(torch.float32).tiny
|
||||
)
|
||||
|
||||
expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
||||
expert2_hist = (
|
||||
100
|
||||
* torch.histc(
|
||||
(indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||
)
|
||||
/ num_tokens
|
||||
)
|
||||
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
|
||||
expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
||||
expert2_hist = (
|
||||
torch.sort(expert2_hist, dim=0, descending=True).values
|
||||
+ torch.finfo(torch.float32).tiny
|
||||
)
|
||||
|
||||
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||
|
@ -355,8 +408,15 @@ def top2gating(
|
|||
if has_tutel:
|
||||
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
||||
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
||||
return l_aux, metadata, capacity, num_experts, \
|
||||
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
|
||||
return (
|
||||
l_aux,
|
||||
metadata,
|
||||
capacity,
|
||||
num_experts,
|
||||
[indices1_s, indices2_s],
|
||||
[locations1_s, locations2_s],
|
||||
[gates1_s, gates2_s],
|
||||
)
|
||||
|
||||
# Store the capacity location for each token
|
||||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||
|
@ -369,11 +429,13 @@ def top2gating(
|
|||
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
|
||||
combine1_sec = torch.bmm(
|
||||
# einsum("se,sc->sec")
|
||||
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
||||
gates1.unsqueeze(-1),
|
||||
locations1_sc.to(gates1.dtype).unsqueeze(1),
|
||||
)
|
||||
combine2_sec = torch.bmm(
|
||||
# einsum("se,sc->sec")
|
||||
gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
|
||||
gates2.unsqueeze(-1),
|
||||
locations2_sc.to(gates2.dtype).unsqueeze(1),
|
||||
)
|
||||
combine_weights = combine1_sec + combine2_sec
|
||||
dispatch_mask = combine_weights.bool()
|
||||
|
@ -406,7 +468,7 @@ class Top2Gate(torch.nn.Module):
|
|||
model_dim: int,
|
||||
num_experts: int,
|
||||
use_fp32=False,
|
||||
second_expert_policy='sampling',
|
||||
second_expert_policy="sampling",
|
||||
normalize_gate_prob_before_dropping=False,
|
||||
moe_eval_capacity_token_fraction=0.25,
|
||||
batch_prioritized_routing=False,
|
||||
|
|
|
@@ -3,37 +3,35 @@

import torch
import torch.nn as nn

from torchscale.architecture.encoder import Encoder
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
from torchscale.component.embedding import (
PositionalEmbedding,
TextEmbedding,
VisionEmbedding,
)
from torchscale.component.multiway_network import MultiwayWrapper


class BEiT3(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
self.args = args
assert args.multiway
assert args.vocab_size > 0
assert not args.share_encoder_input_output_embed
self.text_embed = TextEmbedding(
args.vocab_size,
args.encoder_embed_dim
)
self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
self.vision_embed = VisionEmbedding(
args.img_size,
args.patch_size,
args.in_chans,
args.encoder_embed_dim,
contain_mask_token=True,
prepend_cls_token=True
prepend_cls_token=True,
)
embed_positions = MultiwayWrapper(
args,
PositionalEmbedding(
args.max_source_positions,
args.encoder_embed_dim
),
PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
dim=1,
)
self.encoder = Encoder(

@@ -71,7 +69,7 @@ class BEiT3(nn.Module):
encoder_padding_mask = torch.cat(
[
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
text_padding_position
text_padding_position,
],
dim=1,
)