Code reformatting
This commit is contained in:
parent
1354614d44
commit
7eca1a531c
|
@ -4,7 +4,6 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
from fairseq_cli.generate import cli_main
|
from fairseq_cli.generate import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
from fairseq_cli.interactive import cli_main
|
from fairseq_cli.interactive import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -2,24 +2,24 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from fairseq import utils
|
|
||||||
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
|
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
|
||||||
from fairseq.models.transformer import (
|
|
||||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
|
||||||
)
|
|
||||||
from fairseq.modules import PositionalEmbedding
|
|
||||||
from fairseq.models.squad import SQuADHead
|
|
||||||
from omegaconf import II
|
|
||||||
from .machine_translation import MTEncoder as Encoder
|
|
||||||
from torchscale.architecture.config import EncoderConfig
|
|
||||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
from fairseq import utils
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
|
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
|
||||||
|
from fairseq.models.squad import SQuADHead
|
||||||
|
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
||||||
|
from fairseq.modules import PositionalEmbedding
|
||||||
|
from omegaconf import II
|
||||||
|
|
||||||
|
from torchscale.architecture.config import EncoderConfig
|
||||||
|
|
||||||
|
from .machine_translation import MTEncoder as Encoder
|
||||||
|
|
||||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
|
||||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||||
"--offload-activations are passed."
|
"--offload-activations are passed."
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
max_source_positions: int = field(
|
max_source_positions: int = field(
|
||||||
default=1024, metadata={"help": "max source positions"}
|
default=1024, metadata={"help": "max source positions"}
|
||||||
|
@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
|
||||||
default="relu", metadata={"help": "activation function to use for pooler layer"}
|
default="relu", metadata={"help": "activation function to use for pooler layer"}
|
||||||
)
|
)
|
||||||
pooler_dropout: float = field(
|
pooler_dropout: float = field(
|
||||||
default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
|
default=0.0,
|
||||||
|
metadata={"help": "dropout probability in the masked_lm pooler layers"},
|
||||||
)
|
)
|
||||||
# options from other parts of the config
|
# options from other parts of the config
|
||||||
# add_bos_token: bool = II("task.add_bos_token")
|
# add_bos_token: bool = II("task.add_bos_token")
|
||||||
# tokens_per_sample: int = II("task.tokens_per_sample")
|
# tokens_per_sample: int = II("task.tokens_per_sample")
|
||||||
tpu: bool = II("common.tpu")
|
tpu: bool = II("common.tpu")
|
||||||
rel_pos_buckets: int = field(
|
rel_pos_buckets: int = field(default=0, metadata={"help": ""})
|
||||||
default=0, metadata={"help": ""}
|
max_rel_pos: int = field(default=0, metadata={"help": ""})
|
||||||
)
|
|
||||||
max_rel_pos: int = field(
|
|
||||||
default=0, metadata={"help": ""}
|
|
||||||
)
|
|
||||||
moe_freq: int = field(
|
moe_freq: int = field(
|
||||||
default=0,
|
default=0,
|
||||||
metadata={
|
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
|
||||||
"help": "Frequency at which we insert MoE Transformer layers"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
moe_expert_count: int = field(
|
moe_expert_count: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Number of experts in each MoE Layer"}
|
||||||
metadata={
|
|
||||||
"help": "Number of experts in each MoE Layer"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_gating_use_fp32: bool = field(
|
moe_gating_use_fp32: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
|
||||||
"help": "Use FP32 computations in MoE top2 gating function"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_second_expert_policy: str = field(
|
moe_second_expert_policy: str = field(
|
||||||
default='sampling',
|
default="sampling",
|
||||||
metadata={
|
metadata={"help": "policy for second expert, options: all/sampling/random"},
|
||||||
"help": "policy for second expert, options: all/sampling/random"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_normalize_gate_prob_before_dropping: bool = field(
|
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
moe_expert_ffn_dim: Optional[int] = field(
|
moe_expert_ffn_dim: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "MoE expert FFN dimension"}
|
||||||
metadata={
|
|
||||||
"help": "MoE expert FFN dimension"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_top1_expert: Optional[bool] = field(
|
moe_top1_expert: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Use top1 gate instead of top2"}
|
||||||
metadata={
|
|
||||||
"help": "Use top1 gate instead of top2"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
default=0.25,
|
default=0.25,
|
||||||
|
@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
|
||||||
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
||||||
"if set to negative, use same as training. range: (0.0, 1.0]."
|
"if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
moe_normalize_expert_grad: Optional[str] = field(
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
default='world_size',
|
default="world_size",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
record_a2a_perf_stats: Optional[bool] = field(
|
record_a2a_perf_stats: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
default=False,
|
||||||
|
metadata={"help": "records all to all perf stats during distributed training"},
|
||||||
)
|
)
|
||||||
dummy_a2a: Optional[bool] = field(
|
dummy_a2a: Optional[bool] = field(
|
||||||
default=False, metadata={
|
default=False,
|
||||||
"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
metadata={
|
||||||
|
"help": "By passes all to all during distributed training by returning the input buffer as output"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "if true orders token by the gate prob before capacity dropping."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
ddp_rank: int = II("distributed_training.distributed_rank")
|
ddp_rank: int = II("distributed_training.distributed_rank")
|
||||||
deepnorm: Optional[bool] = field(
|
deepnorm: Optional[bool] = field(
|
||||||
|
@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):
|
||||||
|
|
||||||
@register_model("mlm", dataclass=BertConfig)
|
@register_model("mlm", dataclass=BertConfig)
|
||||||
class BertModel(BaseFairseqModel):
|
class BertModel(BaseFairseqModel):
|
||||||
|
|
||||||
def __init__(self, args, encoder):
|
def __init__(self, args, encoder):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
|
@ -240,7 +227,11 @@ class BertModel(BaseFairseqModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
lm_head = cls.build_lm_head(
|
lm_head = cls.build_lm_head(
|
||||||
args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
|
args,
|
||||||
|
args.encoder_embed_dim,
|
||||||
|
len(task.dictionary),
|
||||||
|
args.activation_fn,
|
||||||
|
weight=embed_tokens.weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = EncoderConfig()
|
config = EncoderConfig()
|
||||||
|
@ -269,7 +260,9 @@ class BertModel(BaseFairseqModel):
|
||||||
def output_layer(self, features, masked_tokens=None):
|
def output_layer(self, features, masked_tokens=None):
|
||||||
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
|
return self.encoder.output_projection(features, masked_tokens=masked_tokens)
|
||||||
|
|
||||||
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
|
def register_classification_head(
|
||||||
|
self, name, num_classes=None, inner_dim=None, **kwargs
|
||||||
|
):
|
||||||
"""Register a classification head."""
|
"""Register a classification head."""
|
||||||
if name in self.classification_heads:
|
if name in self.classification_heads:
|
||||||
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
prev_num_classes = self.classification_heads[name].out_proj.out_features
|
||||||
|
@ -277,7 +270,7 @@ class BertModel(BaseFairseqModel):
|
||||||
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
're-registering head "{}" with num_classes {} (prev: {}) '
|
're-registering head "{}" with num_classes {} (prev: {}) '
|
||||||
'and inner_dim {} (prev: {})'.format(
|
"and inner_dim {} (prev: {})".format(
|
||||||
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -295,42 +288,51 @@ class BertModel(BaseFairseqModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
def upgrade_state_dict_named(self, state_dict, name):
|
def upgrade_state_dict_named(self, state_dict, name):
|
||||||
prefix = name + '.' if name != '' else ''
|
prefix = name + "." if name != "" else ""
|
||||||
|
|
||||||
# upgrade children modules
|
# upgrade children modules
|
||||||
super().upgrade_state_dict_named(state_dict, name)
|
super().upgrade_state_dict_named(state_dict, name)
|
||||||
|
|
||||||
# Handle new classification heads present in the state dict.
|
# Handle new classification heads present in the state dict.
|
||||||
current_head_names = (
|
current_head_names = (
|
||||||
[] if not hasattr(self, 'classification_heads')
|
[]
|
||||||
|
if not hasattr(self, "classification_heads")
|
||||||
else self.classification_heads.keys()
|
else self.classification_heads.keys()
|
||||||
)
|
)
|
||||||
keys_to_delete = []
|
keys_to_delete = []
|
||||||
for k in state_dict.keys():
|
for k in state_dict.keys():
|
||||||
if not k.startswith(prefix + 'classification_heads.'):
|
if not k.startswith(prefix + "classification_heads."):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
|
head_name = k[len(prefix + "classification_heads.") :].split(".")[0] # noqa: E203
|
||||||
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
|
num_classes = state_dict[
|
||||||
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
|
prefix + "classification_heads." + head_name + ".out_proj.weight"
|
||||||
|
].size(0)
|
||||||
|
inner_dim = state_dict[
|
||||||
|
prefix + "classification_heads." + head_name + ".dense.weight"
|
||||||
|
].size(0)
|
||||||
|
|
||||||
if getattr(self.args, 'load_checkpoint_heads', False):
|
if getattr(self.args, "load_checkpoint_heads", False):
|
||||||
if head_name not in current_head_names:
|
if head_name not in current_head_names:
|
||||||
self.register_classification_head(head_name, num_classes, inner_dim)
|
self.register_classification_head(head_name, num_classes, inner_dim)
|
||||||
else:
|
else:
|
||||||
if head_name not in current_head_names:
|
if head_name not in current_head_names:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
'deleting classification head ({}) from checkpoint '
|
"deleting classification head ({}) from checkpoint "
|
||||||
'not present in current model: {}'.format(head_name, k)
|
"not present in current model: {}".format(head_name, k)
|
||||||
)
|
)
|
||||||
keys_to_delete.append(k)
|
keys_to_delete.append(k)
|
||||||
elif (
|
elif (
|
||||||
num_classes != self.classification_heads[head_name].out_proj.out_features
|
num_classes
|
||||||
or inner_dim != self.classification_heads[head_name].dense.out_features
|
!= self.classification_heads[head_name].out_proj.out_features
|
||||||
|
or inner_dim
|
||||||
|
!= self.classification_heads[head_name].dense.out_features
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
'deleting classification head ({}) from checkpoint '
|
"deleting classification head ({}) from checkpoint "
|
||||||
'with different dimensions than current model: {}'.format(head_name, k)
|
"with different dimensions than current model: {}".format(
|
||||||
|
head_name, k
|
||||||
|
)
|
||||||
)
|
)
|
||||||
keys_to_delete.append(k)
|
keys_to_delete.append(k)
|
||||||
for k in keys_to_delete:
|
for k in keys_to_delete:
|
||||||
|
@ -338,12 +340,12 @@ class BertModel(BaseFairseqModel):
|
||||||
|
|
||||||
# Copy any newly-added classification heads into the state dict
|
# Copy any newly-added classification heads into the state dict
|
||||||
# with their current weights.
|
# with their current weights.
|
||||||
if hasattr(self, 'classification_heads'):
|
if hasattr(self, "classification_heads"):
|
||||||
cur_state = self.classification_heads.state_dict()
|
cur_state = self.classification_heads.state_dict()
|
||||||
for k, v in cur_state.items():
|
for k, v in cur_state.items():
|
||||||
if prefix + 'classification_heads.' + k not in state_dict:
|
if prefix + "classification_heads." + k not in state_dict:
|
||||||
logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
|
logger.info("Overwriting " + prefix + "classification_heads." + k)
|
||||||
state_dict[prefix + 'classification_heads.' + k] = v
|
state_dict[prefix + "classification_heads." + k] = v
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -354,7 +356,9 @@ class BertModel(BaseFairseqModel):
|
||||||
masked_tokens=None,
|
masked_tokens=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
|
encoder_out = self.encoder(
|
||||||
|
src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
|
||||||
|
)
|
||||||
x, extra = encoder_out["encoder_out"], encoder_out
|
x, extra = encoder_out["encoder_out"], encoder_out
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
|
@ -455,7 +459,7 @@ def base_unilm_architecture(args):
|
||||||
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
|
args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
|
||||||
|
|
||||||
# Model training is not stable without this
|
# Model training is not stable without this
|
||||||
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
|
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
||||||
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
|
args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
|
||||||
|
|
||||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
||||||
|
|
|
@ -9,10 +9,9 @@
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
|
||||||
|
|
||||||
from fairseq import utils
|
import torch
|
||||||
from fairseq import distributed_utils
|
from fairseq import distributed_utils, utils
|
||||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqIncrementalDecoder,
|
FairseqIncrementalDecoder,
|
||||||
|
@ -20,14 +19,13 @@ from fairseq.models import (
|
||||||
register_model,
|
register_model,
|
||||||
register_model_architecture,
|
register_model_architecture,
|
||||||
)
|
)
|
||||||
from fairseq.models.transformer import (
|
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
|
||||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding,
|
|
||||||
)
|
|
||||||
from fairseq.modules import PositionalEmbedding
|
from fairseq.modules import PositionalEmbedding
|
||||||
from torchscale.architecture.decoder import Decoder
|
|
||||||
from torchscale.architecture.config import DecoderConfig
|
|
||||||
from omegaconf import II
|
from omegaconf import II
|
||||||
|
|
||||||
|
from torchscale.architecture.config import DecoderConfig
|
||||||
|
from torchscale.architecture.decoder import Decoder
|
||||||
|
|
||||||
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -104,49 +102,34 @@ class LanguageConfig(FairseqDataclass):
|
||||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||||
"--offload-activations are passed."
|
"--offload-activations are passed."
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
moe_freq: int = field(
|
moe_freq: int = field(
|
||||||
default=0,
|
default=0,
|
||||||
metadata={
|
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
|
||||||
"help": "Frequency at which we insert MoE Transformer layers"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
moe_expert_count: int = field(
|
moe_expert_count: int = field(
|
||||||
default=0,
|
default=0, metadata={"help": "Number of experts in each MoE Layer"}
|
||||||
metadata={
|
|
||||||
"help": "Number of experts in each MoE Layer"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_gating_use_fp32: bool = field(
|
moe_gating_use_fp32: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
|
||||||
"help": "Use FP32 computations in MoE top2 gating function"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_second_expert_policy: str = field(
|
moe_second_expert_policy: str = field(
|
||||||
default='sampling',
|
default="sampling",
|
||||||
metadata={
|
metadata={"help": "policy for second expert, options: all/sampling/random"},
|
||||||
"help": "policy for second expert, options: all/sampling/random"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_normalize_gate_prob_before_dropping: bool = field(
|
moe_normalize_gate_prob_before_dropping: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
"help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
|
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
moe_expert_ffn_dim: Optional[int] = field(
|
moe_expert_ffn_dim: Optional[int] = field(
|
||||||
default=None,
|
default=None, metadata={"help": "MoE expert FFN dimension"}
|
||||||
metadata={
|
|
||||||
"help": "MoE expert FFN dimension"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_top1_expert: Optional[bool] = field(
|
moe_top1_expert: Optional[bool] = field(
|
||||||
default=False,
|
default=False, metadata={"help": "Use top1 gate instead of top2"}
|
||||||
metadata={
|
|
||||||
"help": "Use top1 gate instead of top2"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
moe_eval_capacity_token_fraction: Optional[float] = field(
|
moe_eval_capacity_token_fraction: Optional[float] = field(
|
||||||
default=0.25,
|
default=0.25,
|
||||||
|
@ -155,23 +138,29 @@ class LanguageConfig(FairseqDataclass):
|
||||||
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
"Default: 0.25, Fraction of tokens as capacity during validation, "
|
||||||
"if set to negative, use same as training. range: (0.0, 1.0]."
|
"if set to negative, use same as training. range: (0.0, 1.0]."
|
||||||
)
|
)
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
moe_normalize_expert_grad: Optional[str] = field(
|
moe_normalize_expert_grad: Optional[str] = field(
|
||||||
default='world_size',
|
default="world_size",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
record_a2a_perf_stats: Optional[bool] = field(
|
record_a2a_perf_stats: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "records all to all perf stats during distributed training"}
|
default=False,
|
||||||
|
metadata={"help": "records all to all perf stats during distributed training"},
|
||||||
)
|
)
|
||||||
dummy_a2a: Optional[bool] = field(
|
dummy_a2a: Optional[bool] = field(
|
||||||
default=False, metadata={
|
default=False,
|
||||||
"help": "By passes all to all during distributed training by returning the input buffer as output"}
|
metadata={
|
||||||
|
"help": "By passes all to all during distributed training by returning the input buffer as output"
|
||||||
|
},
|
||||||
)
|
)
|
||||||
moe_batch_prioritized_routing: Optional[bool] = field(
|
moe_batch_prioritized_routing: Optional[bool] = field(
|
||||||
default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "if true orders token by the gate prob before capacity dropping."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
use_xmoe: Optional[bool] = field(
|
use_xmoe: Optional[bool] = field(
|
||||||
default=False,
|
default=False,
|
||||||
|
@ -205,7 +194,6 @@ class LanguageConfig(FairseqDataclass):
|
||||||
|
|
||||||
@register_model("lm", dataclass=LanguageConfig)
|
@register_model("lm", dataclass=LanguageConfig)
|
||||||
class LanguageModel(FairseqLanguageModel):
|
class LanguageModel(FairseqLanguageModel):
|
||||||
|
|
||||||
def __init__(self, args, decoder):
|
def __init__(self, args, decoder):
|
||||||
self.args = args
|
self.args = args
|
||||||
super().__init__(decoder)
|
super().__init__(decoder)
|
||||||
|
@ -245,19 +233,17 @@ class LanguageModel(FairseqLanguageModel):
|
||||||
args.decoder_embed_dim, len(task.dictionary), bias=False
|
args.decoder_embed_dim, len(task.dictionary), bias=False
|
||||||
)
|
)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if getattr(args, "moe_freq", 0) > 0 and (
|
||||||
getattr(args, 'moe_freq', 0) > 0
|
getattr(args, "fp16", False)
|
||||||
and (
|
and not getattr(args, "memory_efficient_fp16", False)
|
||||||
getattr(args, 'fp16', False)
|
and getattr(args, "ddp_backend", None) != "fully_sharded"
|
||||||
and not getattr(args, 'memory_efficient_fp16', False)
|
|
||||||
and getattr(args, 'ddp_backend', None) != "fully_sharded"
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
assert args.fp16_no_flatten_grads, \
|
assert (
|
||||||
"If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
|
args.fp16_no_flatten_grads
|
||||||
|
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
|
||||||
|
|
||||||
args.ddp_rank = distributed_utils.get_data_parallel_rank()
|
args.ddp_rank = distributed_utils.get_data_parallel_rank()
|
||||||
|
|
||||||
|
@ -281,7 +267,6 @@ class LanguageModel(FairseqLanguageModel):
|
||||||
|
|
||||||
|
|
||||||
class LMDecoder(Decoder, FairseqIncrementalDecoder):
|
class LMDecoder(Decoder, FairseqIncrementalDecoder):
|
||||||
|
|
||||||
def forward(self, src_tokens, **kwargs):
|
def forward(self, src_tokens, **kwargs):
|
||||||
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
|
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
|
||||||
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
|
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)
|
||||||
|
|
|
@ -6,12 +6,12 @@
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fairseq import utils
|
from fairseq import distributed_utils, utils
|
||||||
from fairseq.distributed import utils as fsdp_wrap
|
from fairseq.distributed import utils as fsdp_wrap
|
||||||
from fairseq import distributed_utils
|
|
||||||
from fairseq.models import (
|
from fairseq.models import (
|
||||||
FairseqEncoder,
|
FairseqEncoder,
|
||||||
FairseqEncoderDecoderModel,
|
FairseqEncoderDecoderModel,
|
||||||
|
@ -20,12 +20,13 @@ from fairseq.models import (
|
||||||
)
|
)
|
||||||
from fairseq.models.transformer import Embedding
|
from fairseq.models.transformer import Embedding
|
||||||
from fairseq.modules import PositionalEmbedding
|
from fairseq.modules import PositionalEmbedding
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from torchscale.architecture.config import DecoderConfig, EncoderConfig
|
||||||
from torchscale.architecture.encoder import Encoder
|
from torchscale.architecture.encoder import Encoder
|
||||||
from torchscale.architecture.config import EncoderConfig, DecoderConfig
|
|
||||||
from .language_modeling import LMDecoder as MTDecoder
|
from .language_modeling import LMDecoder as MTDecoder
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||||
|
@ -35,7 +36,6 @@ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
|
||||||
|
|
||||||
@register_model("mt")
|
@register_model("mt")
|
||||||
class TranslationModel(FairseqEncoderDecoderModel):
|
class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
|
|
||||||
def __init__(self, args, encoder, decoder):
|
def __init__(self, args, encoder, decoder):
|
||||||
super().__init__(encoder, decoder)
|
super().__init__(encoder, decoder)
|
||||||
self.args = args
|
self.args = args
|
||||||
|
@ -269,7 +269,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
args.decoder_embed_dim, len(tgt_dict), bias=False
|
args.decoder_embed_dim, len(tgt_dict), bias=False
|
||||||
)
|
)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder = cls.build_encoder(
|
encoder = cls.build_encoder(
|
||||||
|
@ -320,7 +320,9 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_decoder(cls, args, embed_tokens, embed_positions, output_projection, dictionary):
|
def build_decoder(
|
||||||
|
cls, args, embed_tokens, embed_positions, output_projection, dictionary
|
||||||
|
):
|
||||||
config = DecoderConfig()
|
config = DecoderConfig()
|
||||||
config.override(args)
|
config.override(args)
|
||||||
|
|
||||||
|
@ -342,10 +344,7 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
features_only: bool = False,
|
features_only: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
encoder_out = self.encoder(
|
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
|
||||||
src_tokens,
|
|
||||||
return_all_hiddens=return_all_hiddens
|
|
||||||
)
|
|
||||||
decoder_out = self.decoder(
|
decoder_out = self.decoder(
|
||||||
prev_output_tokens,
|
prev_output_tokens,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
@ -365,15 +364,20 @@ class TranslationModel(FairseqEncoderDecoderModel):
|
||||||
|
|
||||||
|
|
||||||
class MTEncoder(Encoder, FairseqEncoder):
|
class MTEncoder(Encoder, FairseqEncoder):
|
||||||
|
|
||||||
def forward(self, src_tokens, **kwargs):
|
def forward(self, src_tokens, **kwargs):
|
||||||
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
|
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
|
||||||
return super().forward(src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs)
|
return super().forward(
|
||||||
|
src_tokens=src_tokens, encoder_padding_mask=self_attn_padding_mask, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def reorder_encoder_out(self, encoder_out, new_order):
|
def reorder_encoder_out(self, encoder_out, new_order):
|
||||||
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
|
new_encoder_out = encoder_out["encoder_out"].index_select(1, new_order)
|
||||||
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(0, new_order)
|
new_encoder_embedding = encoder_out["encoder_embedding"].index_select(
|
||||||
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(0, new_order)
|
0, new_order
|
||||||
|
)
|
||||||
|
new_encoder_padding_mask = encoder_out["encoder_padding_mask"].index_select(
|
||||||
|
0, new_order
|
||||||
|
)
|
||||||
|
|
||||||
encoder_states = encoder_out["encoder_states"]
|
encoder_states = encoder_out["encoder_states"]
|
||||||
if len(encoder_states) > 0:
|
if len(encoder_states) > 0:
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from infinibatch.iterators import CheckpointableIterator
|
from infinibatch.iterators import CheckpointableIterator
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,7 +26,6 @@ class BaseBatchGen(CheckpointableIterator):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def _move_to_tensor(self, batch):
|
def _move_to_tensor(self, batch):
|
||||||
|
|
||||||
def to_tensor(x):
|
def to_tensor(x):
|
||||||
return torch.tensor(x)
|
return torch.tensor(x)
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import os
|
|
||||||
import numpy as np
|
|
||||||
import itertools
|
|
||||||
import copy
|
import copy
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from infinibatch import iterators
|
from infinibatch import iterators
|
||||||
|
|
||||||
from .basic_loader import BaseBatchGen
|
from .basic_loader import BaseBatchGen
|
||||||
from .utils import NativeCheckpointableIterator, WeightIterator
|
from .utils import NativeCheckpointableIterator, WeightIterator
|
||||||
|
|
||||||
|
|
||||||
class MLMLoader(BaseBatchGen):
|
class MLMLoader(BaseBatchGen):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -62,9 +62,7 @@ class MLMLoader(BaseBatchGen):
|
||||||
log_empty_buffer_warning=True and self.shard_id == 0,
|
log_empty_buffer_warning=True and self.shard_id == 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
prefetch_batches = iterators.MapIterator(
|
prefetch_batches = iterators.MapIterator(prefetch_batches, self._move_to_tensor)
|
||||||
prefetch_batches, self._move_to_tensor
|
|
||||||
)
|
|
||||||
|
|
||||||
self._iter = prefetch_batches
|
self._iter = prefetch_batches
|
||||||
|
|
||||||
|
@ -73,25 +71,25 @@ class MLMLoader(BaseBatchGen):
|
||||||
weights = []
|
weights = []
|
||||||
|
|
||||||
for data in self.data:
|
for data in self.data:
|
||||||
multilingual_iters.append(
|
multilingual_iters.append(self._tokenize(data))
|
||||||
self._tokenize(data)
|
if "weight" in data:
|
||||||
)
|
weights.append(float(data["weight"]))
|
||||||
if 'weight' in data:
|
|
||||||
weights.append(float(data['weight']))
|
|
||||||
else:
|
else:
|
||||||
weights.append(int(data['count']))
|
weights.append(int(data["count"]))
|
||||||
|
|
||||||
if len(multilingual_iters) == 1:
|
if len(multilingual_iters) == 1:
|
||||||
return multilingual_iters[0]
|
return multilingual_iters[0]
|
||||||
|
|
||||||
sampling_iterator = WeightIterator(weights)
|
sampling_iterator = WeightIterator(weights)
|
||||||
control_iterator = NativeCheckpointableIterator(sampling_iterator)
|
control_iterator = NativeCheckpointableIterator(sampling_iterator)
|
||||||
tokenized_lines = iterators.MultiplexIterator(control_iterator, multilingual_iters)
|
tokenized_lines = iterators.MultiplexIterator(
|
||||||
|
control_iterator, multilingual_iters
|
||||||
|
)
|
||||||
|
|
||||||
return tokenized_lines
|
return tokenized_lines
|
||||||
|
|
||||||
def _tokenize(self, data):
|
def _tokenize(self, data):
|
||||||
'''
|
"""
|
||||||
data:
|
data:
|
||||||
{
|
{
|
||||||
'source': list[Path],
|
'source': list[Path],
|
||||||
|
@ -100,17 +98,16 @@ class MLMLoader(BaseBatchGen):
|
||||||
'weight': float,
|
'weight': float,
|
||||||
'name': str,
|
'name': str,
|
||||||
}
|
}
|
||||||
'''
|
"""
|
||||||
dataset = list(
|
dataset = list(
|
||||||
zip(
|
zip(
|
||||||
data['source'],
|
data["source"],
|
||||||
itertools.repeat(data['source_lang']),
|
itertools.repeat(data["source_lang"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
chunk_files = \
|
chunk_files = iterators.InfinitePermutationSourceIterator(
|
||||||
iterators.InfinitePermutationSourceIterator(
|
|
||||||
dataset,
|
dataset,
|
||||||
seed=self.seed,
|
seed=self.seed,
|
||||||
shuffle=self.shuffle,
|
shuffle=self.shuffle,
|
||||||
|
@ -118,15 +115,18 @@ class MLMLoader(BaseBatchGen):
|
||||||
instance_rank=self.shard_id,
|
instance_rank=self.shard_id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
chunk_files = \
|
chunk_files = iterators.ChunkedSourceIterator(
|
||||||
iterators.ChunkedSourceIterator(
|
|
||||||
dataset,
|
dataset,
|
||||||
num_instances=self.num_shards,
|
num_instances=self.num_shards,
|
||||||
instance_rank=self.shard_id,
|
instance_rank=self.shard_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
tokenized_lines = iterators.SelectManyIterator(chunk_files, lambda files: self._read_from_files(*files))
|
tokenized_lines = iterators.SelectManyIterator(
|
||||||
tokenized_lines = iterators.SamplingRandomMapIterator(tokenized_lines, self._prepare, self.seed)
|
chunk_files, lambda files: self._read_from_files(*files)
|
||||||
|
)
|
||||||
|
tokenized_lines = iterators.SamplingRandomMapIterator(
|
||||||
|
tokenized_lines, self._prepare, self.seed
|
||||||
|
)
|
||||||
|
|
||||||
return tokenized_lines
|
return tokenized_lines
|
||||||
|
|
||||||
|
@ -134,13 +134,20 @@ class MLMLoader(BaseBatchGen):
|
||||||
|
|
||||||
if self.max_sentences is not None:
|
if self.max_sentences is not None:
|
||||||
if self.batch_read_ahead > 0:
|
if self.batch_read_ahead > 0:
|
||||||
lines = iterators.BlockwiseShuffleIterator(lines, self.batch_read_ahead, self.seed)
|
lines = iterators.BlockwiseShuffleIterator(
|
||||||
|
lines, self.batch_read_ahead, self.seed
|
||||||
|
)
|
||||||
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
|
batches = iterators.FixedBatchIterator(lines, self.max_sentences)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def dynamic_batch_size(sample):
|
def dynamic_batch_size(sample):
|
||||||
lengths = [len(x) for x in sample]
|
lengths = [len(x) for x in sample]
|
||||||
batch_size = self.max_tokens // max(lengths)
|
batch_size = self.max_tokens // max(lengths)
|
||||||
batch_size = batch_size // self.required_batch_size_multiple * self.required_batch_size_multiple
|
batch_size = (
|
||||||
|
batch_size
|
||||||
|
// self.required_batch_size_multiple
|
||||||
|
* self.required_batch_size_multiple
|
||||||
|
)
|
||||||
return max(1, batch_size)
|
return max(1, batch_size)
|
||||||
|
|
||||||
batches = iterators.BucketedReadaheadBatchIterator(
|
batches = iterators.BucketedReadaheadBatchIterator(
|
||||||
|
@ -160,38 +167,56 @@ class MLMLoader(BaseBatchGen):
|
||||||
s2s_source_max_length = max([len(x[2]) for x in batch])
|
s2s_source_max_length = max([len(x[2]) for x in batch])
|
||||||
s2s_target_max_length = max([len(x[3]) for x in batch])
|
s2s_target_max_length = max([len(x[3]) for x in batch])
|
||||||
|
|
||||||
mlm_source_ids = np.full(shape=(batch_size, mlm_source_max_length), dtype=np.int32,
|
mlm_source_ids = np.full(
|
||||||
fill_value=self.dictionary.pad())
|
shape=(batch_size, mlm_source_max_length),
|
||||||
mlm_target_ids = np.full(shape=(batch_size, mlm_target_max_length), dtype=np.int32,
|
dtype=np.int32,
|
||||||
fill_value=self.dictionary.pad())
|
fill_value=self.dictionary.pad(),
|
||||||
s2s_source_ids = np.full(shape=(batch_size, s2s_source_max_length), dtype=np.int32,
|
)
|
||||||
fill_value=self.dictionary.pad())
|
mlm_target_ids = np.full(
|
||||||
s2s_target_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
|
shape=(batch_size, mlm_target_max_length),
|
||||||
fill_value=self.dictionary.pad())
|
dtype=np.int32,
|
||||||
s2s_prev_input_ids = np.full(shape=(batch_size, s2s_target_max_length-1), dtype=np.int32,
|
fill_value=self.dictionary.pad(),
|
||||||
fill_value=self.dictionary.pad())
|
)
|
||||||
|
s2s_source_ids = np.full(
|
||||||
|
shape=(batch_size, s2s_source_max_length),
|
||||||
|
dtype=np.int32,
|
||||||
|
fill_value=self.dictionary.pad(),
|
||||||
|
)
|
||||||
|
s2s_target_ids = np.full(
|
||||||
|
shape=(batch_size, s2s_target_max_length - 1),
|
||||||
|
dtype=np.int32,
|
||||||
|
fill_value=self.dictionary.pad(),
|
||||||
|
)
|
||||||
|
s2s_prev_input_ids = np.full(
|
||||||
|
shape=(batch_size, s2s_target_max_length - 1),
|
||||||
|
dtype=np.int32,
|
||||||
|
fill_value=self.dictionary.pad(),
|
||||||
|
)
|
||||||
|
|
||||||
for i, (mlm_input_ids, mlm_label_ids, s2s_input_ids, s2s_label_ids) in enumerate(batch):
|
for i, (
|
||||||
mlm_source_ids[i, :len(mlm_input_ids)] = mlm_input_ids
|
mlm_input_ids,
|
||||||
mlm_target_ids[i, :len(mlm_label_ids)] = mlm_label_ids
|
mlm_label_ids,
|
||||||
s2s_source_ids[i, :len(s2s_input_ids)] = s2s_input_ids
|
s2s_input_ids,
|
||||||
s2s_target_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[1:]
|
s2s_label_ids,
|
||||||
s2s_prev_input_ids[i, :len(s2s_label_ids)-1] = s2s_label_ids[:-1]
|
) in enumerate(batch):
|
||||||
|
mlm_source_ids[i, : len(mlm_input_ids)] = mlm_input_ids
|
||||||
|
mlm_target_ids[i, : len(mlm_label_ids)] = mlm_label_ids
|
||||||
|
s2s_source_ids[i, : len(s2s_input_ids)] = s2s_input_ids
|
||||||
|
s2s_target_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[1:]
|
||||||
|
s2s_prev_input_ids[i, : len(s2s_label_ids) - 1] = s2s_label_ids[:-1]
|
||||||
|
|
||||||
ret_batch = {
|
ret_batch = {
|
||||||
'net_input': {
|
"net_input": {
|
||||||
'src_tokens': mlm_source_ids.astype(np.int64),
|
"src_tokens": mlm_source_ids.astype(np.int64),
|
||||||
},
|
},
|
||||||
'target': mlm_target_ids.astype(np.int64),
|
"target": mlm_target_ids.astype(np.int64),
|
||||||
'nsentences': batch_size,
|
"nsentences": batch_size,
|
||||||
'ntokens': sum([len(x[0]) for x in batch]),
|
"ntokens": sum([len(x[0]) for x in batch]),
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret_batch
|
return ret_batch
|
||||||
|
|
||||||
padded_batches = iterators.MapIterator(
|
padded_batches = iterators.MapIterator(batches, collate)
|
||||||
batches, collate
|
|
||||||
)
|
|
||||||
|
|
||||||
return padded_batches
|
return padded_batches
|
||||||
|
|
||||||
|
@ -221,7 +246,6 @@ class MLMLoader(BaseBatchGen):
|
||||||
return nonmasked_tokens, masked_tokens
|
return nonmasked_tokens, masked_tokens
|
||||||
|
|
||||||
def _span_corruption(self, _random, doc):
|
def _span_corruption(self, _random, doc):
|
||||||
|
|
||||||
def mask_tokens(i):
|
def mask_tokens(i):
|
||||||
return f"<mask_{i}>"
|
return f"<mask_{i}>"
|
||||||
|
|
||||||
|
@ -237,7 +261,9 @@ class MLMLoader(BaseBatchGen):
|
||||||
else:
|
else:
|
||||||
possible_split_positions = list(range(1, noise_tokens_num))
|
possible_split_positions = list(range(1, noise_tokens_num))
|
||||||
_random.shuffle(possible_split_positions)
|
_random.shuffle(possible_split_positions)
|
||||||
noise_split_positions = sorted(possible_split_positions[:noise_spans_num-1])
|
noise_split_positions = sorted(
|
||||||
|
possible_split_positions[: noise_spans_num - 1]
|
||||||
|
)
|
||||||
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
|
noise_split_positions = [0] + noise_split_positions + [noise_tokens_num]
|
||||||
|
|
||||||
possible_insert_positions = list(range(nonnoise_tokens_num))
|
possible_insert_positions = list(range(nonnoise_tokens_num))
|
||||||
|
@ -248,7 +274,7 @@ class MLMLoader(BaseBatchGen):
|
||||||
last_end = 0
|
last_end = 0
|
||||||
for i in range(noise_spans_num):
|
for i in range(noise_spans_num):
|
||||||
start_pos = noise_insert_positions[i] + noise_split_positions[i]
|
start_pos = noise_insert_positions[i] + noise_split_positions[i]
|
||||||
end_pos = noise_insert_positions[i] + noise_split_positions[i+1]
|
end_pos = noise_insert_positions[i] + noise_split_positions[i + 1]
|
||||||
mask_id = self.dictionary.indices[mask_tokens(i)]
|
mask_id = self.dictionary.indices[mask_tokens(i)]
|
||||||
|
|
||||||
if getattr(self.args, "remove_target_sentinel", False):
|
if getattr(self.args, "remove_target_sentinel", False):
|
||||||
|
@ -273,23 +299,25 @@ class MLMLoader(BaseBatchGen):
|
||||||
file_path = os.path.join(self.data_dir, source_file)
|
file_path = os.path.join(self.data_dir, source_file)
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
print('| file {} not exists'.format(file_path), flush=True)
|
print("| file {} not exists".format(file_path), flush=True)
|
||||||
return iter([]) # skip bad file
|
return iter([]) # skip bad file
|
||||||
|
|
||||||
with open(file_path, 'r', encoding='utf8') as f:
|
with open(file_path, "r", encoding="utf8") as f:
|
||||||
lines = f.read().strip().split('\n')
|
lines = f.read().strip().split("\n")
|
||||||
|
|
||||||
doc = [self.dictionary.bos()]
|
doc = [self.dictionary.bos()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line == "":
|
if line == "":
|
||||||
if self.sample_break_mode == 'complete_doc':
|
if self.sample_break_mode == "complete_doc":
|
||||||
# data.append(doc)
|
# data.append(doc)
|
||||||
yield doc
|
yield doc
|
||||||
doc = [self.dictionary.bos()]
|
doc = [self.dictionary.bos()]
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tokenized_line = self.tokenizer.EncodeAsPieces(line)
|
tokenized_line = self.tokenizer.EncodeAsPieces(line)
|
||||||
tokenized_id = [self.dictionary.index(token) for token in tokenized_line] + [self.dictionary.eos_index]
|
tokenized_id = [
|
||||||
|
self.dictionary.index(token) for token in tokenized_line
|
||||||
|
] + [self.dictionary.eos_index]
|
||||||
|
|
||||||
if len(tokenized_id) > self.tokens_per_sample:
|
if len(tokenized_id) > self.tokens_per_sample:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import numpy as np
|
import collections
|
||||||
from random import Random
|
from random import Random
|
||||||
from typing import Dict, Iterable, Optional
|
from typing import Dict, Iterable, Optional
|
||||||
import collections
|
|
||||||
|
import numpy as np
|
||||||
from infinibatch import iterators
|
from infinibatch import iterators
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +18,9 @@ def apply_to_sample(f, sample):
|
||||||
return f(x)
|
return f(x)
|
||||||
elif isinstance(x, collections.OrderedDict):
|
elif isinstance(x, collections.OrderedDict):
|
||||||
# OrderedDict has attributes that needs to be preserved
|
# OrderedDict has attributes that needs to be preserved
|
||||||
od = collections.OrderedDict((key, _apply(value)) for key, value in x.items())
|
od = collections.OrderedDict(
|
||||||
|
(key, _apply(value)) for key, value in x.items()
|
||||||
|
)
|
||||||
od.__dict__ = x.__dict__
|
od.__dict__ = x.__dict__
|
||||||
return od
|
return od
|
||||||
elif isinstance(x, dict):
|
elif isinstance(x, dict):
|
||||||
|
@ -40,14 +43,15 @@ class NativeCheckpointableIterator(iterators.CheckpointableIterator):
|
||||||
self.setstate(None)
|
self.setstate(None)
|
||||||
|
|
||||||
def getstate(self) -> Dict:
|
def getstate(self) -> Dict:
|
||||||
return {'num_items_yielded': self._num_items_yielded}
|
return {"num_items_yielded": self._num_items_yielded}
|
||||||
|
|
||||||
def setstate(self, checkpoint: Optional[Dict]):
|
def setstate(self, checkpoint: Optional[Dict]):
|
||||||
self._iterator = iter(self._input_iterable)
|
self._iterator = iter(self._input_iterable)
|
||||||
self._num_items_yielded = iterators._advance_iterator(
|
self._num_items_yielded = (
|
||||||
self._iterator,
|
iterators._advance_iterator(self._iterator, checkpoint["num_items_yielded"])
|
||||||
checkpoint['num_items_yielded']
|
if checkpoint is not None
|
||||||
) if checkpoint is not None else 0
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
item = next(self._iterator)
|
item = next(self._iterator)
|
||||||
|
@ -73,7 +77,9 @@ class WeightIterator(object):
|
||||||
|
|
||||||
def setstate(self, checkpoint):
|
def setstate(self, checkpoint):
|
||||||
self._random_state = checkpoint["random_state"] if checkpoint else None
|
self._random_state = checkpoint["random_state"] if checkpoint else None
|
||||||
self._random = None # this will trigger the lazy initialization in self.__next__
|
self._random = (
|
||||||
|
None # this will trigger the lazy initialization in self.__next__
|
||||||
|
)
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
if self._random is None:
|
if self._random is None:
|
||||||
|
|
|
@ -1,23 +1,25 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from argparse import Namespace
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
#
|
#
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from argparse import Namespace
|
|
||||||
import json
|
|
||||||
from omegaconf import MISSING, II
|
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
from fairseq import utils
|
from fairseq import utils
|
||||||
from fairseq.data import Dictionary
|
from fairseq.data import Dictionary
|
||||||
|
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||||
from fairseq.tasks import FairseqTask, register_task
|
from fairseq.tasks import FairseqTask, register_task
|
||||||
|
from omegaconf import II, MISSING
|
||||||
|
|
||||||
from .data.mlm_loader import MLMLoader
|
from .data.mlm_loader import MLMLoader
|
||||||
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
|
|
||||||
import sentencepiece as spm
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -109,21 +111,16 @@ class PretrainingConfig(FairseqDataclass):
|
||||||
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
|
required_batch_size_multiple: int = II("dataset.required_batch_size_multiple")
|
||||||
spm_model: str = field(
|
spm_model: str = field(
|
||||||
default="",
|
default="",
|
||||||
metadata={
|
metadata={"help": "sentencepice model to tokenize the data"},
|
||||||
"help": "sentencepice model to tokenize the data"
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
dict_file: str = field(
|
dict_file: str = field(
|
||||||
default="",
|
default="",
|
||||||
metadata={
|
metadata={"help": ""},
|
||||||
"help": ""
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_task("pretraining", dataclass=PretrainingConfig)
|
@register_task("pretraining", dataclass=PretrainingConfig)
|
||||||
class PLMTask(FairseqTask):
|
class PLMTask(FairseqTask):
|
||||||
|
|
||||||
def __init__(self, cfg, dictionary, tokenizer):
|
def __init__(self, cfg, dictionary, tokenizer):
|
||||||
super().__init__(cfg)
|
super().__init__(cfg)
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
@ -156,9 +153,9 @@ class PLMTask(FairseqTask):
|
||||||
|
|
||||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||||
self.datasets[split] = {
|
self.datasets[split] = {
|
||||||
'data': json.load(open(f'{self.cfg.data}/json/{split}.json')),
|
"data": json.load(open(f"{self.cfg.data}/json/{split}.json")),
|
||||||
'data_dir': self.cfg.data,
|
"data_dir": self.cfg.data,
|
||||||
'shuffle': True if split == 'train' else False,
|
"shuffle": True if split == "train" else False,
|
||||||
}
|
}
|
||||||
self.datasets[split] = Namespace(**self.datasets[split])
|
self.datasets[split] = Namespace(**self.datasets[split])
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,6 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import models
|
import models
|
||||||
import tasks
|
import tasks
|
||||||
|
|
||||||
from fairseq_cli.train import cli_main
|
from fairseq_cli.train import cli_main
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -1,17 +1,21 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import torch
|
|
||||||
import warnings
|
|
||||||
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
|
|
||||||
import torch.distributed as dist
|
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from fairseq.utils import multi_tensor_l2norm_available, multi_tensor_total_norm
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None) -> torch.Tensor:
|
def clip_grad_norm_(
|
||||||
|
params, max_norm, moe_expert_count, aggregate_norm_fn=None
|
||||||
|
) -> torch.Tensor:
|
||||||
def grad_exists(p):
|
def grad_exists(p):
|
||||||
return p is not None and getattr(p, "grad", None) is not None
|
return p is not None and getattr(p, "grad", None) is not None
|
||||||
|
|
||||||
if isinstance(params, torch.Tensor):
|
if isinstance(params, torch.Tensor):
|
||||||
params = [params]
|
params = [params]
|
||||||
params = list(params)
|
params = list(params)
|
||||||
|
@ -59,7 +63,9 @@ def clip_grad_norm_(params, max_norm, moe_expert_count, aggregate_norm_fn=None)
|
||||||
for split_grads in [expert_grads, sharded_grads]:
|
for split_grads in [expert_grads, sharded_grads]:
|
||||||
if len(split_grads) == 0:
|
if len(split_grads) == 0:
|
||||||
continue
|
continue
|
||||||
split_norm = torch.norm(torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads]))
|
split_norm = torch.norm(
|
||||||
|
torch.stack([torch.norm(g, p=2, dtype=torch.float32) for g in split_grads])
|
||||||
|
)
|
||||||
if dist.is_initialized():
|
if dist.is_initialized():
|
||||||
split_norm.pow_(2)
|
split_norm.pow_(2)
|
||||||
dist.all_reduce(split_norm)
|
dist.all_reduce(split_norm)
|
||||||
|
|
15
setup.py
15
setup.py
|
@ -2,6 +2,7 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
@ -10,19 +11,15 @@ setup(
|
||||||
author="TorchScale Team",
|
author="TorchScale Team",
|
||||||
author_email="Shuming.Ma@microsoft.com",
|
author_email="Shuming.Ma@microsoft.com",
|
||||||
description="Transformers at any scale",
|
description="Transformers at any scale",
|
||||||
long_description=open("README.md", "r", encoding='utf-8').read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
keywords="Transformers at any scale",
|
keywords="Transformers at any scale",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
url="https://github.com/msranlp/torchscale",
|
url="https://github.com/msranlp/torchscale",
|
||||||
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
|
||||||
"tests.*", "tests"]),
|
install_requires=["apex", "torch>=1.8", "fairscale==0.4.0", "timm==0.4.12"],
|
||||||
install_requires=['apex',
|
python_requires=">=3.8.0",
|
||||||
'torch>=1.8',
|
|
||||||
'fairscale==0.4.0',
|
|
||||||
'timm==0.4.12'],
|
|
||||||
python_requires='>=3.8.0',
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Programming Language :: Python :: 3',
|
"Programming Language :: Python :: 3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,9 +2,10 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from torchscale.architecture.config import DecoderConfig
|
from torchscale.architecture.config import DecoderConfig
|
||||||
from torchscale.architecture.decoder import Decoder
|
from torchscale.architecture.decoder import Decoder
|
||||||
import torch
|
|
||||||
|
|
||||||
testcases = [
|
testcases = [
|
||||||
{},
|
{},
|
||||||
|
@ -20,7 +21,7 @@ testcases = [
|
||||||
{"multiway": True},
|
{"multiway": True},
|
||||||
{"share_decoder_input_output_embed": True},
|
{"share_decoder_input_output_embed": True},
|
||||||
{"checkpoint_activations": True},
|
{"checkpoint_activations": True},
|
||||||
{"fsdp": True}
|
{"fsdp": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,10 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from torchscale.architecture.config import EncoderConfig
|
from torchscale.architecture.config import EncoderConfig
|
||||||
from torchscale.architecture.encoder import Encoder
|
from torchscale.architecture.encoder import Encoder
|
||||||
import torch
|
|
||||||
|
|
||||||
testcases = [
|
testcases = [
|
||||||
{},
|
{},
|
||||||
|
@ -20,7 +21,7 @@ testcases = [
|
||||||
{"multiway": True},
|
{"multiway": True},
|
||||||
{"share_encoder_input_output_embed": True},
|
{"share_encoder_input_output_embed": True},
|
||||||
{"checkpoint_activations": True},
|
{"checkpoint_activations": True},
|
||||||
{"fsdp": True}
|
{"fsdp": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,11 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from torchscale.architecture.config import EncoderDecoderConfig
|
from torchscale.architecture.config import EncoderDecoderConfig
|
||||||
from torchscale.architecture.encoder_decoder import EncoderDecoder
|
from torchscale.architecture.encoder_decoder import EncoderDecoder
|
||||||
from torchscale.component.embedding import TextEmbedding, PositionalEmbedding
|
from torchscale.component.embedding import PositionalEmbedding, TextEmbedding
|
||||||
import torch
|
|
||||||
|
|
||||||
testcases = [
|
testcases = [
|
||||||
{},
|
{},
|
||||||
|
@ -16,13 +17,18 @@ testcases = [
|
||||||
{"no_scale_embedding": False},
|
{"no_scale_embedding": False},
|
||||||
{"layernorm_embedding": True},
|
{"layernorm_embedding": True},
|
||||||
{"rel_pos_buckets": 32, "max_rel_pos": 256},
|
{"rel_pos_buckets": 32, "max_rel_pos": 256},
|
||||||
{"deepnorm": True, "subln": False, "encoder_normalize_before": False, "decoder_normalize_before": False},
|
{
|
||||||
|
"deepnorm": True,
|
||||||
|
"subln": False,
|
||||||
|
"encoder_normalize_before": False,
|
||||||
|
"decoder_normalize_before": False,
|
||||||
|
},
|
||||||
{"bert_init": True},
|
{"bert_init": True},
|
||||||
{"multiway": True},
|
{"multiway": True},
|
||||||
{"share_decoder_input_output_embed": True},
|
{"share_decoder_input_output_embed": True},
|
||||||
{"share_all_embeddings": True},
|
{"share_all_embeddings": True},
|
||||||
{"checkpoint_activations": True},
|
{"checkpoint_activations": True},
|
||||||
{"fsdp": True}
|
{"fsdp": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,8 +39,12 @@ def test_decoder(args):
|
||||||
config,
|
config,
|
||||||
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
|
encoder_embed_tokens=TextEmbedding(64000, config.encoder_embed_dim),
|
||||||
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
|
decoder_embed_tokens=TextEmbedding(64000, config.decoder_embed_dim),
|
||||||
encoder_embed_positions=PositionalEmbedding(config.max_source_positions, config.encoder_embed_dim),
|
encoder_embed_positions=PositionalEmbedding(
|
||||||
decoder_embed_positions=PositionalEmbedding(config.max_target_positions, config.decoder_embed_dim),
|
config.max_source_positions, config.encoder_embed_dim
|
||||||
|
),
|
||||||
|
decoder_embed_positions=PositionalEmbedding(
|
||||||
|
config.max_target_positions, config.decoder_embed_dim
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
src_tokens = torch.ones(2, 20).long()
|
src_tokens = torch.ones(2, 20).long()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
|
|
||||||
class EncoderConfig(object):
|
class EncoderConfig(object):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", 768)
|
||||||
|
@ -19,9 +20,13 @@ class EncoderConfig(object):
|
||||||
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
||||||
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
||||||
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
||||||
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
|
self.moe_eval_capacity_token_fraction = kwargs.pop(
|
||||||
|
"moe_eval_capacity_token_fraction", 0.25
|
||||||
|
)
|
||||||
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
||||||
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
|
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
|
||||||
|
"moe_normalize_gate_prob_before_dropping", False
|
||||||
|
)
|
||||||
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
||||||
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
||||||
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
||||||
|
@ -29,7 +34,9 @@ class EncoderConfig(object):
|
||||||
self.subln = kwargs.pop("subln", True)
|
self.subln = kwargs.pop("subln", True)
|
||||||
self.bert_init = kwargs.pop("bert_init", False)
|
self.bert_init = kwargs.pop("bert_init", False)
|
||||||
self.multiway = kwargs.pop("multiway", False)
|
self.multiway = kwargs.pop("multiway", False)
|
||||||
self.share_encoder_input_output_embed = kwargs.pop("share_encoder_input_output_embed", False)
|
self.share_encoder_input_output_embed = kwargs.pop(
|
||||||
|
"share_encoder_input_output_embed", False
|
||||||
|
)
|
||||||
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
|
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
|
||||||
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
||||||
# Text
|
# Text
|
||||||
|
@ -78,9 +85,13 @@ class DecoderConfig(object):
|
||||||
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
||||||
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
||||||
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
||||||
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
|
self.moe_eval_capacity_token_fraction = kwargs.pop(
|
||||||
|
"moe_eval_capacity_token_fraction", 0.25
|
||||||
|
)
|
||||||
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
||||||
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
|
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
|
||||||
|
"moe_normalize_gate_prob_before_dropping", False
|
||||||
|
)
|
||||||
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
||||||
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
||||||
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
||||||
|
@ -88,7 +99,9 @@ class DecoderConfig(object):
|
||||||
self.subln = kwargs.pop("subln", True)
|
self.subln = kwargs.pop("subln", True)
|
||||||
self.bert_init = kwargs.pop("bert_init", False)
|
self.bert_init = kwargs.pop("bert_init", False)
|
||||||
self.multiway = kwargs.pop("multiway", False)
|
self.multiway = kwargs.pop("multiway", False)
|
||||||
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
|
self.share_decoder_input_output_embed = kwargs.pop(
|
||||||
|
"share_decoder_input_output_embed", False
|
||||||
|
)
|
||||||
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
||||||
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
||||||
# Text
|
# Text
|
||||||
|
@ -138,9 +151,13 @@ class EncoderDecoderConfig(object):
|
||||||
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
||||||
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
||||||
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
|
||||||
self.moe_eval_capacity_token_fraction = kwargs.pop("moe_eval_capacity_token_fraction", 0.25)
|
self.moe_eval_capacity_token_fraction = kwargs.pop(
|
||||||
|
"moe_eval_capacity_token_fraction", 0.25
|
||||||
|
)
|
||||||
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
|
||||||
self.moe_normalize_gate_prob_before_dropping = kwargs.pop("moe_normalize_gate_prob_before_dropping", False)
|
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
|
||||||
|
"moe_normalize_gate_prob_before_dropping", False
|
||||||
|
)
|
||||||
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
self.use_xmoe = kwargs.pop("use_xmoe", False)
|
||||||
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
|
||||||
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
|
||||||
|
@ -149,7 +166,9 @@ class EncoderDecoderConfig(object):
|
||||||
self.bert_init = kwargs.pop("bert_init", False)
|
self.bert_init = kwargs.pop("bert_init", False)
|
||||||
self.multiway = kwargs.pop("multiway", False)
|
self.multiway = kwargs.pop("multiway", False)
|
||||||
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
|
self.share_all_embeddings = kwargs.pop("share_all_embeddings", False)
|
||||||
self.share_decoder_input_output_embed = kwargs.pop("share_decoder_input_output_embed", False)
|
self.share_decoder_input_output_embed = kwargs.pop(
|
||||||
|
"share_decoder_input_output_embed", False
|
||||||
|
)
|
||||||
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
|
self.max_source_positions = kwargs.pop("max_source_positions", 1024)
|
||||||
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
||||||
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
||||||
|
|
|
@ -2,22 +2,23 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
|
||||||
from fairscale.nn import checkpoint_wrapper, wrap
|
|
||||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
from fairscale.nn import checkpoint_wrapper, wrap
|
||||||
|
|
||||||
|
from torchscale.architecture.utils import init_bert_params
|
||||||
|
from torchscale.component.droppath import DropPath
|
||||||
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
||||||
from torchscale.component.multihead_attention import MultiheadAttention
|
from torchscale.component.multihead_attention import MultiheadAttention
|
||||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
|
||||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
|
||||||
from torchscale.component.droppath import DropPath
|
|
||||||
from torchscale.architecture.utils import init_bert_params
|
|
||||||
from torchscale.component.relative_position_bias import RelativePositionBias
|
from torchscale.component.relative_position_bias import RelativePositionBias
|
||||||
|
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||||
|
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||||
|
|
||||||
|
|
||||||
class DecoderLayer(nn.Module):
|
class DecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -31,7 +32,9 @@ class DecoderLayer(nn.Module):
|
||||||
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
||||||
|
|
||||||
if args.drop_path_rate > 0:
|
if args.drop_path_rate > 0:
|
||||||
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[depth]
|
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
|
||||||
|
depth
|
||||||
|
]
|
||||||
self.drop_path = DropPath(drop_path_prob)
|
self.drop_path = DropPath(drop_path_prob)
|
||||||
else:
|
else:
|
||||||
self.drop_path = None
|
self.drop_path = None
|
||||||
|
@ -206,7 +209,6 @@ class DecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -228,7 +230,11 @@ class Decoder(nn.Module):
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
self.embed_positions = embed_positions
|
self.embed_positions = embed_positions
|
||||||
|
|
||||||
if output_projection is None and not args.no_output_layer and args.vocab_size > 0:
|
if (
|
||||||
|
output_projection is None
|
||||||
|
and not args.no_output_layer
|
||||||
|
and args.vocab_size > 0
|
||||||
|
):
|
||||||
self.output_projection = self.build_output_projection(args)
|
self.output_projection = self.build_output_projection(args)
|
||||||
else:
|
else:
|
||||||
self.output_projection = output_projection
|
self.output_projection = output_projection
|
||||||
|
@ -286,7 +292,12 @@ class Decoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
|
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
|
||||||
for name, p in self.named_parameters():
|
for name, p in self.named_parameters():
|
||||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
p.data.div_(init_scale)
|
p.data.div_(init_scale)
|
||||||
|
|
||||||
if args.subln:
|
if args.subln:
|
||||||
|
@ -295,9 +306,14 @@ class Decoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
|
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
|
||||||
for name, p in self.named_parameters():
|
for name, p in self.named_parameters():
|
||||||
if 'encoder_attn' in name:
|
if "encoder_attn" in name:
|
||||||
continue
|
continue
|
||||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
p.data.mul_(init_scale)
|
p.data.mul_(init_scale)
|
||||||
|
|
||||||
def build_output_projection(
|
def build_output_projection(
|
||||||
|
@ -316,16 +332,12 @@ class Decoder(nn.Module):
|
||||||
args.decoder_embed_dim, args.vocab_size, bias=False
|
args.decoder_embed_dim, args.vocab_size, bias=False
|
||||||
)
|
)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
|
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
|
||||||
)
|
)
|
||||||
return output_projection
|
return output_projection
|
||||||
|
|
||||||
def build_decoder_layer(
|
def build_decoder_layer(
|
||||||
self,
|
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
|
||||||
args,
|
|
||||||
depth,
|
|
||||||
is_moe_layer=False,
|
|
||||||
is_encoder_decoder=False
|
|
||||||
):
|
):
|
||||||
layer = DecoderLayer(
|
layer = DecoderLayer(
|
||||||
args,
|
args,
|
||||||
|
@ -347,7 +359,9 @@ class Decoder(nn.Module):
|
||||||
):
|
):
|
||||||
positions = None
|
positions = None
|
||||||
if self.embed_positions is not None:
|
if self.embed_positions is not None:
|
||||||
positions = self.embed_positions(tokens, incremental_state=incremental_state)
|
positions = self.embed_positions(
|
||||||
|
tokens, incremental_state=incremental_state
|
||||||
|
)
|
||||||
|
|
||||||
if incremental_state is not None:
|
if incremental_state is not None:
|
||||||
tokens = tokens[:, -1:]
|
tokens = tokens[:, -1:]
|
||||||
|
@ -381,7 +395,9 @@ class Decoder(nn.Module):
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
# embed tokens and positions
|
# embed tokens and positions
|
||||||
x, _ = self.forward_embedding(prev_output_tokens, token_embeddings, incremental_state)
|
x, _ = self.forward_embedding(
|
||||||
|
prev_output_tokens, token_embeddings, incremental_state
|
||||||
|
)
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
# relative postion
|
# relative postion
|
||||||
|
@ -389,9 +405,7 @@ class Decoder(nn.Module):
|
||||||
slen = prev_output_tokens.size(1)
|
slen = prev_output_tokens.size(1)
|
||||||
if self.self_attn_relative_position is not None:
|
if self.self_attn_relative_position is not None:
|
||||||
self_attn_rel_pos_bias = self.self_attn_relative_position(
|
self_attn_rel_pos_bias = self.self_attn_relative_position(
|
||||||
batch_size=x.size(1),
|
batch_size=x.size(1), qlen=slen, klen=slen
|
||||||
qlen=slen,
|
|
||||||
klen=slen
|
|
||||||
)
|
)
|
||||||
if incremental_state is not None:
|
if incremental_state is not None:
|
||||||
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
|
self_attn_rel_pos_bias = self_attn_rel_pos_bias[:, -1:, :]
|
||||||
|
@ -416,7 +430,11 @@ class Decoder(nn.Module):
|
||||||
for idx, layer in enumerate(self.layers):
|
for idx, layer in enumerate(self.layers):
|
||||||
if incremental_state is None:
|
if incremental_state is None:
|
||||||
self_attn_mask = torch.triu(
|
self_attn_mask = torch.triu(
|
||||||
torch.zeros([x.size(0), x.size(0)]).float().fill_(float("-inf")).type_as(x), 1
|
torch.zeros([x.size(0), x.size(0)])
|
||||||
|
.float()
|
||||||
|
.fill_(float("-inf"))
|
||||||
|
.type_as(x),
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self_attn_mask = None
|
self_attn_mask = None
|
||||||
|
@ -426,7 +444,9 @@ class Decoder(nn.Module):
|
||||||
x, layer_attn, _, l_aux_i = layer(
|
x, layer_attn, _, l_aux_i = layer(
|
||||||
x,
|
x,
|
||||||
encoder_out["encoder_out"] if encoder_out is not None else None,
|
encoder_out["encoder_out"] if encoder_out is not None else None,
|
||||||
encoder_out["encoder_padding_mask"] if encoder_out is not None else None,
|
encoder_out["encoder_padding_mask"]
|
||||||
|
if encoder_out is not None
|
||||||
|
else None,
|
||||||
incremental_state[idx] if incremental_state is not None else None,
|
incremental_state[idx] if incremental_state is not None else None,
|
||||||
self_attn_mask=self_attn_mask,
|
self_attn_mask=self_attn_mask,
|
||||||
self_attn_padding_mask=self_attn_padding_mask,
|
self_attn_padding_mask=self_attn_padding_mask,
|
||||||
|
@ -444,7 +464,11 @@ class Decoder(nn.Module):
|
||||||
if not features_only:
|
if not features_only:
|
||||||
x = self.output_layer(x)
|
x = self.output_layer(x)
|
||||||
|
|
||||||
return x, {"inner_states": inner_states, "l_aux": l_aux, "attn": [layer_attn.mean(dim=0)]}
|
return x, {
|
||||||
|
"inner_states": inner_states,
|
||||||
|
"l_aux": l_aux,
|
||||||
|
"attn": [layer_attn.mean(dim=0)],
|
||||||
|
}
|
||||||
|
|
||||||
def output_layer(self, features):
|
def output_layer(self, features):
|
||||||
return self.output_projection(features)
|
return self.output_projection(features)
|
||||||
|
|
|
@ -2,30 +2,25 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
|
||||||
from fairscale.nn import checkpoint_wrapper, wrap
|
|
||||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
from fairscale.nn import checkpoint_wrapper, wrap
|
||||||
|
|
||||||
|
from torchscale.architecture.utils import init_bert_params
|
||||||
|
from torchscale.component.droppath import DropPath
|
||||||
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
|
||||||
from torchscale.component.multihead_attention import MultiheadAttention
|
from torchscale.component.multihead_attention import MultiheadAttention
|
||||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
from torchscale.component.multiway_network import MultiwayWrapper, set_split_position
|
||||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
|
||||||
from torchscale.component.multiway_network import set_split_position, MultiwayWrapper
|
|
||||||
from torchscale.component.droppath import DropPath
|
|
||||||
from torchscale.architecture.utils import init_bert_params
|
|
||||||
from torchscale.component.relative_position_bias import RelativePositionBias
|
from torchscale.component.relative_position_bias import RelativePositionBias
|
||||||
|
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||||
|
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||||
|
|
||||||
|
|
||||||
class EncoderLayer(nn.Module):
|
class EncoderLayer(nn.Module):
|
||||||
|
def __init__(self, args, depth, is_moe_layer=False, is_encoder_decoder=False):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
args,
|
|
||||||
depth,
|
|
||||||
is_moe_layer=False,
|
|
||||||
is_encoder_decoder=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
self.embed_dim = args.encoder_embed_dim
|
self.embed_dim = args.encoder_embed_dim
|
||||||
|
@ -34,7 +29,9 @@ class EncoderLayer(nn.Module):
|
||||||
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(args.dropout, inplace=True)
|
||||||
|
|
||||||
if args.drop_path_rate > 0:
|
if args.drop_path_rate > 0:
|
||||||
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[depth]
|
drop_path_prob = np.linspace(0, args.drop_path_rate, args.encoder_layers)[
|
||||||
|
depth
|
||||||
|
]
|
||||||
self.drop_path = DropPath(drop_path_prob)
|
self.drop_path = DropPath(drop_path_prob)
|
||||||
else:
|
else:
|
||||||
self.drop_path = None
|
self.drop_path = None
|
||||||
|
@ -49,7 +46,7 @@ class EncoderLayer(nn.Module):
|
||||||
self.build_ffn(
|
self.build_ffn(
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
self.args,
|
self.args,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert not self.args.multiway
|
assert not self.args.multiway
|
||||||
|
@ -77,7 +74,12 @@ class EncoderLayer(nn.Module):
|
||||||
|
|
||||||
if args.deepnorm:
|
if args.deepnorm:
|
||||||
if is_encoder_decoder:
|
if is_encoder_decoder:
|
||||||
self.alpha = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) * 0.81
|
self.alpha = (
|
||||||
|
math.pow(
|
||||||
|
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
|
||||||
|
)
|
||||||
|
* 0.81
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
|
self.alpha = math.pow(2.0 * args.encoder_layers, 0.25)
|
||||||
else:
|
else:
|
||||||
|
@ -107,13 +109,7 @@ class EncoderLayer(nn.Module):
|
||||||
def residual_connection(self, x, residual):
|
def residual_connection(self, x, residual):
|
||||||
return residual * self.alpha + x
|
return residual * self.alpha + x
|
||||||
|
|
||||||
def forward(
|
def forward(self, x, encoder_padding_mask, attn_mask=None, rel_pos=None):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
encoder_padding_mask,
|
|
||||||
attn_mask=None,
|
|
||||||
rel_pos=None
|
|
||||||
):
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
|
attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
|
||||||
|
|
||||||
|
@ -158,7 +154,6 @@ class EncoderLayer(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
class Encoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -179,13 +174,20 @@ class Encoder(nn.Module):
|
||||||
self.embed_tokens = embed_tokens
|
self.embed_tokens = embed_tokens
|
||||||
self.embed_positions = embed_positions
|
self.embed_positions = embed_positions
|
||||||
|
|
||||||
if output_projection is None and not is_encoder_decoder and not args.no_output_layer and args.vocab_size > 0:
|
if (
|
||||||
|
output_projection is None
|
||||||
|
and not is_encoder_decoder
|
||||||
|
and not args.no_output_layer
|
||||||
|
and args.vocab_size > 0
|
||||||
|
):
|
||||||
self.output_projection = self.build_output_projection(args)
|
self.output_projection = self.build_output_projection(args)
|
||||||
else:
|
else:
|
||||||
self.output_projection = output_projection
|
self.output_projection = output_projection
|
||||||
|
|
||||||
if args.layernorm_embedding:
|
if args.layernorm_embedding:
|
||||||
self.layernorm_embedding = MultiwayWrapper(args, LayerNorm(embed_dim), dim=1)
|
self.layernorm_embedding = MultiwayWrapper(
|
||||||
|
args, LayerNorm(embed_dim), dim=1
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.layernorm_embedding = None
|
self.layernorm_embedding = None
|
||||||
|
|
||||||
|
@ -199,7 +201,7 @@ class Encoder(nn.Module):
|
||||||
args,
|
args,
|
||||||
depth=i,
|
depth=i,
|
||||||
is_moe_layer=is_moe_layer,
|
is_moe_layer=is_moe_layer,
|
||||||
is_encoder_decoder=is_encoder_decoder
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.num_layers = len(self.layers)
|
self.num_layers = len(self.layers)
|
||||||
|
@ -223,20 +225,39 @@ class Encoder(nn.Module):
|
||||||
|
|
||||||
if args.deepnorm:
|
if args.deepnorm:
|
||||||
if is_encoder_decoder:
|
if is_encoder_decoder:
|
||||||
init_scale = math.pow(math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625) / 1.15
|
init_scale = (
|
||||||
|
math.pow(
|
||||||
|
math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625
|
||||||
|
)
|
||||||
|
/ 1.15
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
|
init_scale = math.pow(8.0 * args.encoder_layers, 0.25)
|
||||||
for name, p in self.named_parameters():
|
for name, p in self.named_parameters():
|
||||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
p.data.div_(init_scale)
|
p.data.div_(init_scale)
|
||||||
|
|
||||||
if args.subln:
|
if args.subln:
|
||||||
if is_encoder_decoder:
|
if is_encoder_decoder:
|
||||||
init_scale = math.sqrt(math.log(3 * args.decoder_layers) * math.log(2 * args.encoder_layers) / 3)
|
init_scale = math.sqrt(
|
||||||
|
math.log(3 * args.decoder_layers)
|
||||||
|
* math.log(2 * args.encoder_layers)
|
||||||
|
/ 3
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
init_scale = math.sqrt(math.log(args.encoder_layers * 2))
|
init_scale = math.sqrt(math.log(args.encoder_layers * 2))
|
||||||
for name, p in self.named_parameters():
|
for name, p in self.named_parameters():
|
||||||
if 'fc1' in name or 'fc2' in name or 'out_proj' in name or 'v_proj' in name:
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
p.data.mul_(init_scale)
|
p.data.mul_(init_scale)
|
||||||
|
|
||||||
def build_output_projection(
|
def build_output_projection(
|
||||||
|
@ -244,7 +265,7 @@ class Encoder(nn.Module):
|
||||||
args,
|
args,
|
||||||
):
|
):
|
||||||
if args.share_encoder_input_output_embed:
|
if args.share_encoder_input_output_embed:
|
||||||
assert args.encoder_embedding_type == 'language'
|
assert args.encoder_embedding_type == "language"
|
||||||
output_projection = torch.nn.Linear(
|
output_projection = torch.nn.Linear(
|
||||||
self.embed_tokens.weight.shape[1],
|
self.embed_tokens.weight.shape[1],
|
||||||
self.embed_tokens.weight.shape[0],
|
self.embed_tokens.weight.shape[0],
|
||||||
|
@ -256,22 +277,18 @@ class Encoder(nn.Module):
|
||||||
args.encoder_embed_dim, args.vocab_size, bias=False
|
args.encoder_embed_dim, args.vocab_size, bias=False
|
||||||
)
|
)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
output_projection.weight, mean=0, std=args.encoder_embed_dim ** -0.5
|
output_projection.weight, mean=0, std=args.encoder_embed_dim**-0.5
|
||||||
)
|
)
|
||||||
return output_projection
|
return output_projection
|
||||||
|
|
||||||
def build_encoder_layer(
|
def build_encoder_layer(
|
||||||
self,
|
self, args, depth, is_moe_layer=False, is_encoder_decoder=False
|
||||||
args,
|
|
||||||
depth,
|
|
||||||
is_moe_layer=False,
|
|
||||||
is_encoder_decoder=False
|
|
||||||
):
|
):
|
||||||
layer = EncoderLayer(
|
layer = EncoderLayer(
|
||||||
args,
|
args,
|
||||||
depth,
|
depth,
|
||||||
is_moe_layer=is_moe_layer,
|
is_moe_layer=is_moe_layer,
|
||||||
is_encoder_decoder=is_encoder_decoder
|
is_encoder_decoder=is_encoder_decoder,
|
||||||
)
|
)
|
||||||
if args.checkpoint_activations:
|
if args.checkpoint_activations:
|
||||||
layer = checkpoint_wrapper(layer)
|
layer = checkpoint_wrapper(layer)
|
||||||
|
@ -312,13 +329,12 @@ class Encoder(nn.Module):
|
||||||
if encoder_padding_mask is None:
|
if encoder_padding_mask is None:
|
||||||
if src_tokens is not None:
|
if src_tokens is not None:
|
||||||
encoder_padding_mask = torch.zeros_like(
|
encoder_padding_mask = torch.zeros_like(
|
||||||
src_tokens,
|
src_tokens, device=src_tokens.device
|
||||||
device=src_tokens.device
|
|
||||||
).bool()
|
).bool()
|
||||||
else:
|
else:
|
||||||
encoder_padding_mask = torch.zeros(
|
encoder_padding_mask = torch.zeros(
|
||||||
[token_embeddings.size(0), token_embeddings.size(1)],
|
[token_embeddings.size(0), token_embeddings.size(1)],
|
||||||
device=token_embeddings.device
|
device=token_embeddings.device,
|
||||||
).bool()
|
).bool()
|
||||||
|
|
||||||
if multiway_split_position is not None:
|
if multiway_split_position is not None:
|
||||||
|
@ -338,16 +354,13 @@ class Encoder(nn.Module):
|
||||||
rel_pos_bias = None
|
rel_pos_bias = None
|
||||||
if self.relative_position is not None:
|
if self.relative_position is not None:
|
||||||
rel_pos_bias = self.relative_position(
|
rel_pos_bias = self.relative_position(
|
||||||
batch_size=x.size(1),
|
batch_size=x.size(1), qlen=x.size(0), klen=x.size(0)
|
||||||
qlen=x.size(0),
|
|
||||||
klen=x.size(0)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
l_aux = []
|
l_aux = []
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x, l_aux_i = layer(
|
x, l_aux_i = layer(
|
||||||
x, encoder_padding_mask=encoder_padding_mask,
|
x, encoder_padding_mask=encoder_padding_mask, rel_pos=rel_pos_bias
|
||||||
rel_pos=rel_pos_bias
|
|
||||||
)
|
)
|
||||||
if return_all_hiddens:
|
if return_all_hiddens:
|
||||||
assert encoder_states is not None
|
assert encoder_states is not None
|
||||||
|
|
|
@ -2,12 +2,12 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchscale.architecture.encoder import Encoder
|
|
||||||
from torchscale.architecture.decoder import Decoder
|
from torchscale.architecture.decoder import Decoder
|
||||||
|
from torchscale.architecture.encoder import Encoder
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoder(nn.Module):
|
class EncoderDecoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -51,10 +51,7 @@ class EncoderDecoder(nn.Module):
|
||||||
features_only=False,
|
features_only=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
encoder_out = self.encoder(
|
encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
|
||||||
src_tokens,
|
|
||||||
return_all_hiddens=return_all_hiddens
|
|
||||||
)
|
|
||||||
decoder_out = self.decoder(
|
decoder_out = self.decoder(
|
||||||
prev_output_tokens,
|
prev_output_tokens,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
|
|
@ -2,12 +2,12 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torchscale.component.multihead_attention import MultiheadAttention
|
from torchscale.component.multihead_attention import MultiheadAttention
|
||||||
from torchscale.component.multiway_network import MultiwayNetwork
|
from torchscale.component.multiway_network import MultiwayNetwork
|
||||||
|
|
||||||
|
|
||||||
def init_bert_params(module):
|
def init_bert_params(module):
|
||||||
|
|
||||||
def normal_(data):
|
def normal_(data):
|
||||||
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
# Copyright (c) 2022 Microsoft
|
# Copyright (c) 2022 Microsoft
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
from timm.models.layers import drop_path
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from timm.models.layers import drop_path
|
||||||
|
|
||||||
|
|
||||||
class DropPath(nn.Module):
|
class DropPath(nn.Module):
|
||||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||||
"""
|
|
||||||
def __init__(self, drop_prob=None):
|
def __init__(self, drop_prob=None):
|
||||||
super(DropPath, self).__init__()
|
super(DropPath, self).__init__()
|
||||||
self.drop_prob = drop_prob
|
self.drop_prob = drop_prob
|
||||||
|
@ -16,4 +16,4 @@ class DropPath(nn.Module):
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return 'p={}'.format(self.drop_prob)
|
return "p={}".format(self.drop_prob)
|
||||||
|
|
|
@ -7,22 +7,12 @@ import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
class VisionLanguageEmbedding(nn.Module):
|
class VisionLanguageEmbedding(nn.Module):
|
||||||
|
def __init__(self, text_embed, vision_embed):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
text_embed,
|
|
||||||
vision_embed
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.text_embed = text_embed
|
self.text_embed = text_embed
|
||||||
self.vision_embed = vision_embed
|
self.vision_embed = vision_embed
|
||||||
|
|
||||||
def forward(
|
def forward(self, textual_tokens, visual_tokens, **kwargs):
|
||||||
self,
|
|
||||||
textual_tokens,
|
|
||||||
visual_tokens,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
if textual_tokens is None:
|
if textual_tokens is None:
|
||||||
return self.vision_embed(visual_tokens)
|
return self.vision_embed(visual_tokens)
|
||||||
|
|
||||||
|
@ -36,8 +26,8 @@ class VisionLanguageEmbedding(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class VisionEmbedding(nn.Module):
|
class VisionEmbedding(nn.Module):
|
||||||
""" Image to Patch Embedding
|
"""Image to Patch Embedding"""
|
||||||
"""
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
img_size=224,
|
img_size=224,
|
||||||
|
@ -45,7 +35,7 @@ class VisionEmbedding(nn.Module):
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
embed_dim=768,
|
embed_dim=768,
|
||||||
contain_mask_token=False,
|
contain_mask_token=False,
|
||||||
prepend_cls_token=False
|
prepend_cls_token=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
img_size = (img_size, img_size)
|
img_size = (img_size, img_size)
|
||||||
|
@ -56,7 +46,9 @@ class VisionEmbedding(nn.Module):
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
|
|
||||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
self.proj = nn.Conv2d(
|
||||||
|
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
||||||
|
)
|
||||||
|
|
||||||
if contain_mask_token:
|
if contain_mask_token:
|
||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||||
|
@ -68,15 +60,11 @@ class VisionEmbedding(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.cls_token = None
|
self.cls_token = None
|
||||||
|
|
||||||
def forward(
|
def forward(self, x, masked_position=None, **kwargs):
|
||||||
self,
|
|
||||||
x,
|
|
||||||
masked_position=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
assert (
|
||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
H == self.img_size[0] and W == self.img_size[1]
|
||||||
|
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
batch_size, seq_len, _ = x.size()
|
batch_size, seq_len, _ = x.size()
|
||||||
|
@ -88,21 +76,21 @@ class VisionEmbedding(nn.Module):
|
||||||
x = x * (1 - w) + mask_token * w
|
x = x * (1 - w) + mask_token * w
|
||||||
|
|
||||||
if self.cls_token is not None:
|
if self.cls_token is not None:
|
||||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
cls_tokens = self.cls_token.expand(
|
||||||
|
batch_size, -1, -1
|
||||||
|
) # stole cls_tokens impl from Phil Wang, thanks
|
||||||
x = torch.cat((cls_tokens, x), dim=1)
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class TextEmbedding(nn.Embedding):
|
class TextEmbedding(nn.Embedding):
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim ** -0.5)
|
nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
|
||||||
self._fill_padding_idx_with_zero()
|
self._fill_padding_idx_with_zero()
|
||||||
|
|
||||||
|
|
||||||
class PositionalEmbedding(nn.Embedding):
|
class PositionalEmbedding(nn.Embedding):
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
|
@ -111,7 +99,9 @@ class PositionalEmbedding(nn.Embedding):
|
||||||
):
|
):
|
||||||
if positions is None:
|
if positions is None:
|
||||||
# being consistent with Fairseq, which starts from 2.
|
# being consistent with Fairseq, which starts from 2.
|
||||||
positions = torch.arange(2, x.size(1)+2, device=x.device).long().unsqueeze(0)
|
positions = (
|
||||||
|
torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
|
||||||
|
)
|
||||||
return F.embedding(
|
return F.embedding(
|
||||||
positions,
|
positions,
|
||||||
self.weight,
|
self.weight,
|
||||||
|
|
|
@ -35,13 +35,19 @@ class set_torch_seed(object):
|
||||||
|
|
||||||
|
|
||||||
def make_experts(args, embed_dim, expert_ffn_dim):
|
def make_experts(args, embed_dim, expert_ffn_dim):
|
||||||
world_size = 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
|
world_size = (
|
||||||
|
1
|
||||||
|
if not torch.distributed.is_initialized()
|
||||||
|
else torch.distributed.get_world_size()
|
||||||
|
)
|
||||||
expert_list = []
|
expert_list = []
|
||||||
ddp_rank = args.ddp_rank
|
ddp_rank = args.ddp_rank
|
||||||
start_seed = torch.randint(1000000, (1,)).item()
|
start_seed = torch.randint(1000000, (1,)).item()
|
||||||
# at least as many experts than gpus
|
# at least as many experts than gpus
|
||||||
if args.moe_expert_count >= world_size:
|
if args.moe_expert_count >= world_size:
|
||||||
assert args.moe_expert_count % world_size == 0, f'{args.moe_expert_count}, {world_size}'
|
assert (
|
||||||
|
args.moe_expert_count % world_size == 0
|
||||||
|
), f"{args.moe_expert_count}, {world_size}"
|
||||||
local_moe_expert_count = args.moe_expert_count // world_size
|
local_moe_expert_count = args.moe_expert_count // world_size
|
||||||
for i in range(local_moe_expert_count):
|
for i in range(local_moe_expert_count):
|
||||||
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
|
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
|
||||||
|
@ -52,11 +58,13 @@ def make_experts(args, embed_dim, expert_ffn_dim):
|
||||||
args.activation_fn,
|
args.activation_fn,
|
||||||
args.dropout,
|
args.dropout,
|
||||||
args.activation_dropout,
|
args.activation_dropout,
|
||||||
args.subln
|
args.subln,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert world_size % args.moe_expert_count == 0, f'{world_size}, {args.moe_expert_count}'
|
assert (
|
||||||
|
world_size % args.moe_expert_count == 0
|
||||||
|
), f"{world_size}, {args.moe_expert_count}"
|
||||||
|
|
||||||
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
|
with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
|
||||||
expert_list.append(
|
expert_list.append(
|
||||||
|
@ -66,7 +74,7 @@ def make_experts(args, embed_dim, expert_ffn_dim):
|
||||||
args.activation_fn,
|
args.activation_fn,
|
||||||
args.dropout,
|
args.dropout,
|
||||||
args.activation_dropout,
|
args.activation_dropout,
|
||||||
args.subln
|
args.subln,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
experts = nn.ModuleList(expert_list)
|
experts = nn.ModuleList(expert_list)
|
||||||
|
@ -83,7 +91,6 @@ def get_activation_fn(activation):
|
||||||
|
|
||||||
|
|
||||||
class FeedForwardNetwork(nn.Module):
|
class FeedForwardNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
|
@ -91,12 +98,14 @@ class FeedForwardNetwork(nn.Module):
|
||||||
activation_fn,
|
activation_fn,
|
||||||
dropout,
|
dropout,
|
||||||
activation_dropout,
|
activation_dropout,
|
||||||
subln=False
|
subln=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
self.activation_fn = get_activation_fn(activation=str(activation_fn))
|
self.activation_fn = get_activation_fn(activation=str(activation_fn))
|
||||||
self.activation_dropout_module = torch.nn.Dropout(activation_dropout, inplace=True)
|
self.activation_dropout_module = torch.nn.Dropout(
|
||||||
|
activation_dropout, inplace=True
|
||||||
|
)
|
||||||
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
|
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
|
||||||
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
|
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
|
||||||
|
|
|
@ -2,15 +2,16 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from apex.normalization import FusedLayerNorm as LayerNorm
|
from apex.normalization import FusedLayerNorm as LayerNorm
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from .multiway_network import MultiwayWrapper
|
from .multiway_network import MultiwayWrapper
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(nn.Module):
|
class MultiheadAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
args,
|
args,
|
||||||
|
@ -25,7 +26,7 @@ class MultiheadAttention(nn.Module):
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_dim = embed_dim // num_heads
|
self.head_dim = embed_dim // num_heads
|
||||||
self.scaling = self.head_dim ** -0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
|
|
||||||
self.self_attention = self_attention
|
self.self_attention = self_attention
|
||||||
self.encoder_decoder_attention = encoder_decoder_attention
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||||||
|
@ -34,8 +35,14 @@ class MultiheadAttention(nn.Module):
|
||||||
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
||||||
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
|
self.out_proj = MultiwayWrapper(
|
||||||
self.inner_attn_ln = MultiwayWrapper(args, LayerNorm(self.embed_dim)) if subln and self.self_attention else None
|
args, nn.Linear(embed_dim, embed_dim, bias=True)
|
||||||
|
)
|
||||||
|
self.inner_attn_ln = (
|
||||||
|
MultiwayWrapper(args, LayerNorm(self.embed_dim))
|
||||||
|
if subln and self.self_attention
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
self.dropout_module = torch.nn.Dropout(dropout, inplace=True)
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
|
@ -76,12 +83,20 @@ class MultiheadAttention(nn.Module):
|
||||||
|
|
||||||
if incremental_state is not None:
|
if incremental_state is not None:
|
||||||
if "prev_key" in incremental_state:
|
if "prev_key" in incremental_state:
|
||||||
prev_key = incremental_state["prev_key"].view(bsz * self.num_heads, -1, self.head_dim)
|
prev_key = incremental_state["prev_key"].view(
|
||||||
prev_value = incremental_state["prev_value"].view(bsz * self.num_heads, -1, self.head_dim)
|
bsz * self.num_heads, -1, self.head_dim
|
||||||
|
)
|
||||||
|
prev_value = incremental_state["prev_value"].view(
|
||||||
|
bsz * self.num_heads, -1, self.head_dim
|
||||||
|
)
|
||||||
k = torch.cat([prev_key, k], dim=1)
|
k = torch.cat([prev_key, k], dim=1)
|
||||||
v = torch.cat([prev_value, v], dim=1)
|
v = torch.cat([prev_value, v], dim=1)
|
||||||
incremental_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
incremental_state["prev_key"] = k.view(
|
||||||
incremental_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
bsz, self.num_heads, -1, self.head_dim
|
||||||
|
)
|
||||||
|
incremental_state["prev_value"] = v.view(
|
||||||
|
bsz, self.num_heads, -1, self.head_dim
|
||||||
|
)
|
||||||
src_len = k.size(1)
|
src_len = k.size(1)
|
||||||
|
|
||||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||||
|
@ -103,7 +118,9 @@ class MultiheadAttention(nn.Module):
|
||||||
rel_pos = rel_pos.view(attn_weights.size())
|
rel_pos = rel_pos.view(attn_weights.size())
|
||||||
attn_weights = attn_weights + rel_pos
|
attn_weights = attn_weights + rel_pos
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
|
||||||
|
attn_weights
|
||||||
|
)
|
||||||
attn_probs = self.dropout_module(attn_weights)
|
attn_probs = self.dropout_module(attn_weights)
|
||||||
|
|
||||||
attn = torch.bmm(attn_probs, v)
|
attn = torch.bmm(attn_probs, v)
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@ -13,16 +14,14 @@ def MultiwayWrapper(args, module, dim=0):
|
||||||
|
|
||||||
|
|
||||||
def set_split_position(position):
|
def set_split_position(position):
|
||||||
|
|
||||||
def apply_fn(module):
|
def apply_fn(module):
|
||||||
if hasattr(module, 'split_position'):
|
if hasattr(module, "split_position"):
|
||||||
module.split_position = position
|
module.split_position = position
|
||||||
|
|
||||||
return apply_fn
|
return apply_fn
|
||||||
|
|
||||||
|
|
||||||
class MultiwayNetwork(nn.Module):
|
class MultiwayNetwork(nn.Module):
|
||||||
|
|
||||||
def __init__(self, module, dim=0):
|
def __init__(self, module, dim=0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
@ -36,7 +35,11 @@ class MultiwayNetwork(nn.Module):
|
||||||
return self.A(x, **kwargs)
|
return self.A(x, **kwargs)
|
||||||
if self.split_position == 0:
|
if self.split_position == 0:
|
||||||
return self.B(x, **kwargs)
|
return self.B(x, **kwargs)
|
||||||
x1, x2 = torch.split(x, [self.split_position, x.size(self.dim)-self.split_position], dim=self.dim)
|
x1, x2 = torch.split(
|
||||||
|
x,
|
||||||
|
[self.split_position, x.size(self.dim) - self.split_position],
|
||||||
|
dim=self.dim,
|
||||||
|
)
|
||||||
# x1, x2 = x[:self.split_position], x[self.split_position:]
|
# x1, x2 = x[:self.split_position], x[self.split_position:]
|
||||||
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
|
y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
|
||||||
return torch.cat([y1, y2], dim=self.dim)
|
return torch.cat([y1, y2], dim=self.dim)
|
||||||
|
|
|
@ -2,17 +2,14 @@
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class RelativePositionBias(nn.Module):
|
class RelativePositionBias(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
|
||||||
bidirectional=True,
|
|
||||||
num_buckets=32,
|
|
||||||
max_distance=128,
|
|
||||||
n_heads=12
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.bidirectional = bidirectional
|
self.bidirectional = bidirectional
|
||||||
|
@ -23,10 +20,7 @@ class RelativePositionBias(nn.Module):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _relative_position_bucket(
|
def _relative_position_bucket(
|
||||||
relative_position,
|
relative_position, bidirectional=True, num_buckets=32, max_distance=128
|
||||||
bidirectional=True,
|
|
||||||
num_buckets=32,
|
|
||||||
max_distance=128
|
|
||||||
):
|
):
|
||||||
ret = 0
|
ret = 0
|
||||||
n = -relative_position
|
n = -relative_position
|
||||||
|
@ -41,24 +35,28 @@ class RelativePositionBias(nn.Module):
|
||||||
is_small = n < max_exact
|
is_small = n < max_exact
|
||||||
|
|
||||||
val_if_large = max_exact + (
|
val_if_large = max_exact + (
|
||||||
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
|
torch.log(n.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
).to(torch.long)
|
).to(torch.long)
|
||||||
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
|
val_if_large = torch.min(
|
||||||
|
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
|
||||||
ret += torch.where(is_small, n, val_if_large)
|
ret += torch.where(is_small, n, val_if_large)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def compute_bias(
|
def compute_bias(self, qlen, klen, step=None):
|
||||||
self,
|
|
||||||
qlen,
|
|
||||||
klen,
|
|
||||||
step=None
|
|
||||||
):
|
|
||||||
step = 0 if step is None else step
|
step = 0 if step is None else step
|
||||||
context_position = torch.arange(step, step + qlen, dtype=torch.long,
|
context_position = torch.arange(
|
||||||
device=self.relative_attention_bias.weight.device)[:, None]
|
step,
|
||||||
memory_position = torch.arange(klen, dtype=torch.long,
|
step + qlen,
|
||||||
device=self.relative_attention_bias.weight.device)[None, :]
|
dtype=torch.long,
|
||||||
|
device=self.relative_attention_bias.weight.device,
|
||||||
|
)[:, None]
|
||||||
|
memory_position = torch.arange(
|
||||||
|
klen, dtype=torch.long, device=self.relative_attention_bias.weight.device
|
||||||
|
)[None, :]
|
||||||
relative_position = memory_position - context_position # shape (qlen, klen)
|
relative_position = memory_position - context_position # shape (qlen, klen)
|
||||||
|
|
||||||
rp_bucket = self._relative_position_bucket(
|
rp_bucket = self._relative_position_bucket(
|
||||||
|
@ -67,16 +65,18 @@ class RelativePositionBias(nn.Module):
|
||||||
num_buckets=self.num_buckets,
|
num_buckets=self.num_buckets,
|
||||||
)
|
)
|
||||||
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
|
||||||
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
values = self.relative_attention_bias(
|
||||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
|
rp_bucket
|
||||||
|
) # shape (qlen, klen, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(
|
||||||
|
0
|
||||||
|
) # shape (1, num_heads, qlen, klen)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def forward(
|
def forward(self, batch_size, qlen, klen, step=None):
|
||||||
self,
|
|
||||||
batch_size,
|
|
||||||
qlen,
|
|
||||||
klen,
|
|
||||||
step=None
|
|
||||||
):
|
|
||||||
# shape (batch * num_heads, qlen, klen)
|
# shape (batch * num_heads, qlen, klen)
|
||||||
return self.compute_bias(qlen, klen, step).repeat(batch_size, 1, 1, 1).view(-1, qlen, klen)
|
return (
|
||||||
|
self.compute_bias(qlen, klen, step)
|
||||||
|
.repeat(batch_size, 1, 1, 1)
|
||||||
|
.view(-1, qlen, klen)
|
||||||
|
)
|
||||||
|
|
|
@ -18,9 +18,9 @@ import torch.distributed as dist
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Module, ModuleList
|
from torch.nn import Module, ModuleList
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from fairseq.modules.moe import MOELayer
|
from fairseq.modules.moe import MOELayer
|
||||||
|
|
||||||
has_fairseq = True
|
has_fairseq = True
|
||||||
Base = MOELayer
|
Base = MOELayer
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
@ -81,8 +81,10 @@ def get_moe_group(moe_expert_count):
|
||||||
else:
|
else:
|
||||||
assert world_size % moe_expert_count == 0
|
assert world_size % moe_expert_count == 0
|
||||||
ranks_per_group = world_size // moe_expert_count
|
ranks_per_group = world_size // moe_expert_count
|
||||||
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)]
|
moe_groups = [
|
||||||
for i in range(moe_expert_count)]
|
[i + j * moe_expert_count for j in range(ranks_per_group)]
|
||||||
|
for i in range(moe_expert_count)
|
||||||
|
]
|
||||||
|
|
||||||
get_moe_group._moe_group_idx = moe_groups
|
get_moe_group._moe_group_idx = moe_groups
|
||||||
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
get_moe_group._moe_groups = [dist.new_group(g) for g in moe_groups]
|
||||||
|
@ -105,11 +107,15 @@ def get_all2all_group(moe_expert_count):
|
||||||
else:
|
else:
|
||||||
assert world_size % moe_expert_count == 0
|
assert world_size % moe_expert_count == 0
|
||||||
ranks_per_group = world_size // moe_expert_count
|
ranks_per_group = world_size // moe_expert_count
|
||||||
all2all_groups = [[i * moe_expert_count + j for j in range(moe_expert_count)]
|
all2all_groups = [
|
||||||
for i in range(ranks_per_group)]
|
[i * moe_expert_count + j for j in range(moe_expert_count)]
|
||||||
|
for i in range(ranks_per_group)
|
||||||
|
]
|
||||||
|
|
||||||
get_all2all_group._all2all_group_idx = all2all_groups
|
get_all2all_group._all2all_group_idx = all2all_groups
|
||||||
get_all2all_group._all2all_groups = [dist.new_group(g) for g in all2all_groups]
|
get_all2all_group._all2all_groups = [
|
||||||
|
dist.new_group(g) for g in all2all_groups
|
||||||
|
]
|
||||||
|
|
||||||
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
|
my_group_idx = _find_my_group_index(get_all2all_group._all2all_group_idx)
|
||||||
return get_all2all_group._all2all_groups[my_group_idx]
|
return get_all2all_group._all2all_groups[my_group_idx]
|
||||||
|
@ -133,12 +139,7 @@ class MOELayer(Base):
|
||||||
expert network
|
expert network
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, gate, experts, args):
|
||||||
self,
|
|
||||||
gate,
|
|
||||||
experts,
|
|
||||||
args
|
|
||||||
):
|
|
||||||
if has_fairseq:
|
if has_fairseq:
|
||||||
super(Base, self).__init__()
|
super(Base, self).__init__()
|
||||||
else:
|
else:
|
||||||
|
@ -163,9 +164,13 @@ class MOELayer(Base):
|
||||||
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
|
def forward(self, *input: Tensor, input_padding_mask=None, **kwargs: Any) -> Tensor:
|
||||||
assert len(input) == 1, "only single input Tensor supported"
|
assert len(input) == 1, "only single input Tensor supported"
|
||||||
input = input[0]
|
input = input[0]
|
||||||
assert len(input.shape) == 3, "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
|
assert (
|
||||||
|
len(input.shape) == 3
|
||||||
|
), "input Tensor must have dimensions: (s)equence, (t)oken, (m)odel"
|
||||||
if input_padding_mask is not None:
|
if input_padding_mask is not None:
|
||||||
assert len(input_padding_mask.shape) == 2, "input Tensor must have dimensions: (s)equence, (t)oken"
|
assert (
|
||||||
|
len(input_padding_mask.shape) == 2
|
||||||
|
), "input Tensor must have dimensions: (s)equence, (t)oken"
|
||||||
assert input_padding_mask.shape[0] == input.shape[0]
|
assert input_padding_mask.shape[0] == input.shape[0]
|
||||||
assert input_padding_mask.shape[1] == input.shape[1]
|
assert input_padding_mask.shape[1] == input.shape[1]
|
||||||
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
|
# assert input.shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
|
||||||
|
@ -174,81 +179,120 @@ class MOELayer(Base):
|
||||||
d_model = input.shape[2]
|
d_model = input.shape[2]
|
||||||
# Pad to expected batch size
|
# Pad to expected batch size
|
||||||
input_shape = list(input.shape)
|
input_shape = list(input.shape)
|
||||||
expected_bsz = getattr(self.args, 'batch_size', 0) if self.training else getattr(self.args, 'batch_size_valid', 0)
|
expected_bsz = (
|
||||||
|
getattr(self.args, "batch_size", 0)
|
||||||
|
if self.training
|
||||||
|
else getattr(self.args, "batch_size_valid", 0)
|
||||||
|
)
|
||||||
# This indicates that --batch-size or --max-sentences is not specified
|
# This indicates that --batch-size or --max-sentences is not specified
|
||||||
if expected_bsz is None:
|
if expected_bsz is None:
|
||||||
expected_bsz = 0
|
expected_bsz = 0
|
||||||
# Note: Padding is not necessary at generation time at present
|
# Note: Padding is not necessary at generation time at present
|
||||||
# because all DDP workers process the same batch. Also, batch size at generation time
|
# because all DDP workers process the same batch. Also, batch size at generation time
|
||||||
# can be different from that present in the checkpoint state
|
# can be different from that present in the checkpoint state
|
||||||
if not self.in_generation and expected_bsz != 0 and input_shape[0] != expected_bsz:
|
if (
|
||||||
logger.warning(f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})")
|
not self.in_generation
|
||||||
|
and expected_bsz != 0
|
||||||
|
and input_shape[0] != expected_bsz
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"padding batch with unexpected size {input_shape[0]} (expected: {expected_bsz})"
|
||||||
|
)
|
||||||
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
|
assert input_shape[0] < expected_bsz, f"{input_shape[0]} < {expected_bsz}"
|
||||||
padded_input = torch.zeros(
|
padded_input = torch.zeros(
|
||||||
(expected_bsz, input_shape[1], input_shape[2]),
|
(expected_bsz, input_shape[1], input_shape[2]),
|
||||||
dtype=input.dtype, layout=input.layout, device=input.device)
|
dtype=input.dtype,
|
||||||
padded_input[:input_shape[0], :, :] = input
|
layout=input.layout,
|
||||||
|
device=input.device,
|
||||||
|
)
|
||||||
|
padded_input[: input_shape[0], :, :] = input
|
||||||
input = padded_input
|
input = padded_input
|
||||||
|
|
||||||
padded_input_padding_mask = torch.ones(
|
padded_input_padding_mask = torch.ones(
|
||||||
(expected_bsz, input_shape[1], ), dtype=torch.bool, device=input.device
|
(
|
||||||
|
expected_bsz,
|
||||||
|
input_shape[1],
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=input.device,
|
||||||
)
|
)
|
||||||
if input_padding_mask is not None:
|
if input_padding_mask is not None:
|
||||||
padded_input_padding_mask[:input_shape[0], :] = input_padding_mask
|
padded_input_padding_mask[: input_shape[0], :] = input_padding_mask
|
||||||
else:
|
else:
|
||||||
padded_input_padding_mask[:input_shape[0], :] = False
|
padded_input_padding_mask[: input_shape[0], :] = False
|
||||||
input_padding_mask = padded_input_padding_mask
|
input_padding_mask = padded_input_padding_mask
|
||||||
|
|
||||||
# Reshape into S tokens by dropping sequence dimension.
|
# Reshape into S tokens by dropping sequence dimension.
|
||||||
reshaped_input = input.reshape(-1, d_model)
|
reshaped_input = input.reshape(-1, d_model)
|
||||||
reshaped_input_shape = reshaped_input.shape
|
reshaped_input_shape = reshaped_input.shape
|
||||||
reshaped_input_padding_mask = input_padding_mask.reshape(-1) if input_padding_mask is not None else None
|
reshaped_input_padding_mask = (
|
||||||
|
input_padding_mask.reshape(-1) if input_padding_mask is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
# Doing padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||||
# Pro of --max-tokens: more flexible for MT variable sequence lengths
|
# Pro of --max-tokens: more flexible for MT variable sequence lengths
|
||||||
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
|
# Con of --max-tokens: extra all-reduce needed to figure out optimal padding without running OOM
|
||||||
if expected_bsz == 0:
|
if expected_bsz == 0:
|
||||||
expected_dim = reshaped_input_shape[0] * torch.ones((1,), dtype=torch.long, device=input.device)
|
expected_dim = reshaped_input_shape[0] * torch.ones(
|
||||||
|
(1,), dtype=torch.long, device=input.device
|
||||||
|
)
|
||||||
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
|
dist.all_reduce(expected_dim, group=dist.group.WORLD, op=dist.ReduceOp.MAX)
|
||||||
expected_dim = int(expected_dim.item())
|
expected_dim = int(expected_dim.item())
|
||||||
padded_input = torch.zeros(
|
padded_input = torch.zeros(
|
||||||
(expected_dim, reshaped_input_shape[1]),
|
(expected_dim, reshaped_input_shape[1]),
|
||||||
dtype=input.dtype, layout=input.layout, device=input.device)
|
dtype=input.dtype,
|
||||||
padded_input[:reshaped_input_shape[0], :] = reshaped_input
|
layout=input.layout,
|
||||||
|
device=input.device,
|
||||||
|
)
|
||||||
|
padded_input[: reshaped_input_shape[0], :] = reshaped_input
|
||||||
reshaped_input = padded_input
|
reshaped_input = padded_input
|
||||||
|
|
||||||
padded_input_padding_mask = torch.ones(
|
padded_input_padding_mask = torch.ones(
|
||||||
(expected_dim,), dtype=torch.bool, device=padded_input.device
|
(expected_dim,), dtype=torch.bool, device=padded_input.device
|
||||||
)
|
)
|
||||||
if reshaped_input_padding_mask is not None:
|
if reshaped_input_padding_mask is not None:
|
||||||
padded_input_padding_mask[:reshaped_input_shape[0]] = reshaped_input_padding_mask
|
padded_input_padding_mask[
|
||||||
|
: reshaped_input_shape[0]
|
||||||
|
] = reshaped_input_padding_mask
|
||||||
else:
|
else:
|
||||||
padded_input_padding_mask[:reshaped_input_shape[0]] = False
|
padded_input_padding_mask[: reshaped_input_shape[0]] = False
|
||||||
reshaped_input_padding_mask = padded_input_padding_mask
|
reshaped_input_padding_mask = padded_input_padding_mask
|
||||||
|
|
||||||
if has_tutel:
|
if has_tutel:
|
||||||
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(reshaped_input, reshaped_input_padding_mask)
|
l_aux, self.metadata, C, E, indices_, locations_, gates_ = self.gate(
|
||||||
|
reshaped_input, reshaped_input_padding_mask
|
||||||
|
)
|
||||||
S, M = reshaped_input.size(0), reshaped_input.size(1)
|
S, M = reshaped_input.size(0), reshaped_input.size(1)
|
||||||
|
|
||||||
if not hasattr(self, '_tutel_dispatcher'):
|
if not hasattr(self, "_tutel_dispatcher"):
|
||||||
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
|
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
|
||||||
|
E, C, M, dispatch_dtype=reshaped_input.dtype
|
||||||
|
)
|
||||||
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
|
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
|
||||||
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
|
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
|
||||||
else:
|
else:
|
||||||
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(reshaped_input, reshaped_input_padding_mask)
|
l_aux, combine_weights, dispatch_mask, self.metadata = self.gate(
|
||||||
|
reshaped_input, reshaped_input_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
dispatch_mask = dispatch_mask.to(input.dtype).permute(1, 2, 0) # S,E,C -> E,C,S
|
dispatch_mask = dispatch_mask.to(input.dtype).permute(
|
||||||
|
1, 2, 0
|
||||||
|
) # S,E,C -> E,C,S
|
||||||
E, C, S = dispatch_mask.size()
|
E, C, S = dispatch_mask.size()
|
||||||
M = reshaped_input.size(1)
|
M = reshaped_input.size(1)
|
||||||
assert reshaped_input.size() == (S, M)
|
assert reshaped_input.size() == (S, M)
|
||||||
# einsum("sec,sm->ecm")
|
# einsum("sec,sm->ecm")
|
||||||
dispatched_input = torch.mm(dispatch_mask.view(E*C, S), reshaped_input) # -> (E*C),M
|
dispatched_input = torch.mm(
|
||||||
|
dispatch_mask.view(E * C, S), reshaped_input
|
||||||
|
) # -> (E*C),M
|
||||||
|
|
||||||
if self.all2all_size > 1:
|
if self.all2all_size > 1:
|
||||||
dispatched_input = self.all_to_all_wrapper(dispatched_input)
|
dispatched_input = self.all_to_all_wrapper(dispatched_input)
|
||||||
|
|
||||||
# Re-shape after all-to-all: ecm -> gecm
|
# Re-shape after all-to-all: ecm -> gecm
|
||||||
dispatched_input = dispatched_input.reshape(self.all2all_size, self.num_local_experts, -1, d_model)
|
dispatched_input = dispatched_input.reshape(
|
||||||
|
self.all2all_size, self.num_local_experts, -1, d_model
|
||||||
|
)
|
||||||
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
|
chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
|
||||||
expert_outputs = []
|
expert_outputs = []
|
||||||
for chunk, expert in zip(chunks, self.experts):
|
for chunk, expert in zip(chunks, self.experts):
|
||||||
|
@ -259,18 +303,24 @@ class MOELayer(Base):
|
||||||
expert_output = self.all_to_all_wrapper(expert_output)
|
expert_output = self.all_to_all_wrapper(expert_output)
|
||||||
|
|
||||||
# Re-shape back: gecm -> ecm
|
# Re-shape back: gecm -> ecm
|
||||||
expert_output = expert_output.reshape(self.all2all_size * self.num_local_experts, -1, d_model)
|
expert_output = expert_output.reshape(
|
||||||
|
self.all2all_size * self.num_local_experts, -1, d_model
|
||||||
|
)
|
||||||
|
|
||||||
if has_tutel:
|
if has_tutel:
|
||||||
combined_output = self._tutel_dispatcher.decode(expert_output.view(E*C, M))
|
combined_output = self._tutel_dispatcher.decode(
|
||||||
|
expert_output.view(E * C, M)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# einsum("sec,ecm->sm")
|
# einsum("sec,ecm->sm")
|
||||||
combined_output = combine_weights.view(S, E*C).mm(expert_output.view(E*C, M))
|
combined_output = combine_weights.view(S, E * C).mm(
|
||||||
|
expert_output.view(E * C, M)
|
||||||
|
)
|
||||||
|
|
||||||
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
# Remove padding here when --max-tokens is specified and not --batch-size or --max-sentences
|
||||||
combined_output = combined_output[:reshaped_input_shape[0], :]
|
combined_output = combined_output[: reshaped_input_shape[0], :]
|
||||||
combined_output = combined_output.reshape(input.shape)
|
combined_output = combined_output.reshape(input.shape)
|
||||||
combined_output = combined_output[:input_shape[0], :, :]
|
combined_output = combined_output[: input_shape[0], :, :]
|
||||||
|
|
||||||
self.record_all_to_all_stats()
|
self.record_all_to_all_stats()
|
||||||
|
|
||||||
|
@ -280,7 +330,7 @@ class MOELayer(Base):
|
||||||
self.in_generation = True
|
self.in_generation = True
|
||||||
|
|
||||||
def all_to_all_wrapper(self, input: Tensor):
|
def all_to_all_wrapper(self, input: Tensor):
|
||||||
dummy_a2a = getattr(self.args, 'dummy_a2a', False)
|
dummy_a2a = getattr(self.args, "dummy_a2a", False)
|
||||||
if dummy_a2a:
|
if dummy_a2a:
|
||||||
input = input.contiguous()
|
input = input.contiguous()
|
||||||
output = input.detach().clone()
|
output = input.detach().clone()
|
||||||
|
@ -294,13 +344,13 @@ class MOELayer(Base):
|
||||||
output = _AllToAll.apply(self.all2all_group, input)
|
output = _AllToAll.apply(self.all2all_group, input)
|
||||||
cuda_end.record()
|
cuda_end.record()
|
||||||
cpu_end = time.time() * 1000
|
cpu_end = time.time() * 1000
|
||||||
self.a2a_cpu_time_ms += (cpu_end - cpu_start)
|
self.a2a_cpu_time_ms += cpu_end - cpu_start
|
||||||
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
|
self.a2a_cuda_event_intervals.append((cuda_start, cuda_end))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def record_all_to_all_stats(self):
|
def record_all_to_all_stats(self):
|
||||||
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
|
# controlled via an argument as we want to minimize any impact from torch.cuda.synchronize()
|
||||||
record_a2a_perf_stats = getattr(self.args, 'record_a2a_perf_stats', False)
|
record_a2a_perf_stats = getattr(self.args, "record_a2a_perf_stats", False)
|
||||||
if record_a2a_perf_stats:
|
if record_a2a_perf_stats:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
|
self.metadata["all_to_all_cpu_time_ms"] = self.a2a_cpu_time_ms
|
||||||
|
|
|
@ -13,14 +13,14 @@
|
||||||
# NOTE: This is a mirror of the code in
|
# NOTE: This is a mirror of the code in
|
||||||
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
# https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/moe
|
||||||
|
|
||||||
from typing import Callable, Dict, Tuple, Optional
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
from torch import Tensor
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from .moe_layer import has_tutel, fused_cumsum_sub_one
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from .moe_layer import fused_cumsum_sub_one, has_tutel
|
||||||
|
|
||||||
# use a fixed temperature to compute balance loss
|
# use a fixed temperature to compute balance loss
|
||||||
TEMPERATURE_FOR_L_UAX = 0.07
|
TEMPERATURE_FOR_L_UAX = 0.07
|
||||||
|
@ -65,13 +65,22 @@ def top1gating(
|
||||||
indices1_s = torch.argmax(gates, dim=1)
|
indices1_s = torch.argmax(gates, dim=1)
|
||||||
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
|
mask1 = one_hot(indices1_s, num_classes=num_experts, unsqueeze_indices=True)
|
||||||
if input_mask is not None and input_mask.any():
|
if input_mask is not None and input_mask.any():
|
||||||
nonpadding = ~ input_mask
|
nonpadding = ~input_mask
|
||||||
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
|
|
||||||
# for logging (percent of tokens routed to each expert)
|
# for logging (percent of tokens routed to each expert)
|
||||||
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
expert1_hist = (
|
||||||
|
100
|
||||||
|
* torch.histc(
|
||||||
|
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||||
|
)
|
||||||
|
/ num_tokens
|
||||||
|
)
|
||||||
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||||
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
expert1_hist = (
|
||||||
|
torch.sort(expert1_hist, dim=0, descending=True).values
|
||||||
|
+ torch.finfo(torch.float32).tiny
|
||||||
|
)
|
||||||
|
|
||||||
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||||
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||||
|
@ -91,7 +100,21 @@ def top1gating(
|
||||||
|
|
||||||
if has_tutel:
|
if has_tutel:
|
||||||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
return l_aux, metadata, capacity, num_experts, [indices1_s, ], [locations1_s, ], [gates1_s, ]
|
return (
|
||||||
|
l_aux,
|
||||||
|
metadata,
|
||||||
|
capacity,
|
||||||
|
num_experts,
|
||||||
|
[
|
||||||
|
indices1_s,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
locations1_s,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
gates1_s,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Remove locations outside capacity from mask
|
# Remove locations outside capacity from mask
|
||||||
mask1 = mask1 * torch.lt(locations1, capacity)
|
mask1 = mask1 * torch.lt(locations1, capacity)
|
||||||
|
@ -104,7 +127,8 @@ def top1gating(
|
||||||
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
|
locations1_sc = one_hot(locations1_s, num_classes=capacity, unsqueeze_indices=True)
|
||||||
combine1_sec = torch.bmm(
|
combine1_sec = torch.bmm(
|
||||||
# einsum("se,sc->sec")
|
# einsum("se,sc->sec")
|
||||||
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
gates1.unsqueeze(-1),
|
||||||
|
locations1_sc.to(gates1.dtype).unsqueeze(1),
|
||||||
)
|
)
|
||||||
dispatch_mask = combine1_sec.bool()
|
dispatch_mask = combine1_sec.bool()
|
||||||
if use_fp32:
|
if use_fp32:
|
||||||
|
@ -218,10 +242,10 @@ def one_hot(indices: torch.Tensor, num_classes: int, unsqueeze_indices=False) ->
|
||||||
if unsqueeze_indices:
|
if unsqueeze_indices:
|
||||||
indices = indices.unsqueeze(-1)
|
indices = indices.unsqueeze(-1)
|
||||||
assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
|
assert indices.shape[-1] == 1, "last dimension of indices must be have size 1"
|
||||||
output = torch.zeros(indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype)
|
output = torch.zeros(
|
||||||
output.scatter_(
|
indices.shape[:-1] + (num_classes,), device=indices.device, dtype=indices.dtype
|
||||||
len(output.shape) - 1, indices, 1
|
|
||||||
)
|
)
|
||||||
|
output.scatter_(len(output.shape) - 1, indices, 1)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
@ -235,7 +259,7 @@ def top2gating(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
input_mask: Optional[torch.Tensor] = None,
|
input_mask: Optional[torch.Tensor] = None,
|
||||||
use_fp32=False,
|
use_fp32=False,
|
||||||
second_expert_policy='sampling',
|
second_expert_policy="sampling",
|
||||||
normalize_gate_prob_before_dropping=False,
|
normalize_gate_prob_before_dropping=False,
|
||||||
eval_mode=False,
|
eval_mode=False,
|
||||||
moe_eval_capacity_token_fraction=0.25,
|
moe_eval_capacity_token_fraction=0.25,
|
||||||
|
@ -260,7 +284,7 @@ def top2gating(
|
||||||
# Create a mask for 1st's expert per token
|
# Create a mask for 1st's expert per token
|
||||||
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
|
indices1_s = torch.argmax(gates, dim=1, keepdim=True)
|
||||||
mask1 = one_hot(indices1_s, num_experts)
|
mask1 = one_hot(indices1_s, num_experts)
|
||||||
if second_expert_policy == 'sampling':
|
if second_expert_policy == "sampling":
|
||||||
# Create a mask for 2nd's expert per token using Gumbel-max trick
|
# Create a mask for 2nd's expert per token using Gumbel-max trick
|
||||||
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
|
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
|
||||||
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
|
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
|
||||||
|
@ -281,13 +305,13 @@ def top2gating(
|
||||||
gates1_s = gates1_s / denom_s
|
gates1_s = gates1_s / denom_s
|
||||||
gates2_s = gates2_s / denom_s
|
gates2_s = gates2_s / denom_s
|
||||||
|
|
||||||
if second_expert_policy == 'random':
|
if second_expert_policy == "random":
|
||||||
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
|
sampled = (2 * gates2_s) > torch.rand_like(gates2_s)
|
||||||
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
|
mask2 = mask2 * sampled.repeat(num_experts, 1).transpose(1, 0)
|
||||||
|
|
||||||
# Compute locations in capacity buffer
|
# Compute locations in capacity buffer
|
||||||
if input_mask is not None and input_mask.any():
|
if input_mask is not None and input_mask.any():
|
||||||
nonpadding = ~ input_mask
|
nonpadding = ~input_mask
|
||||||
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
mask1 = mask1 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
mask2 = mask2 * nonpadding.unsqueeze(-1).to(mask1.dtype)
|
||||||
|
|
||||||
|
@ -296,15 +320,22 @@ def top2gating(
|
||||||
importance_scores = -1 * gates.max(dim=1)[0]
|
importance_scores = -1 * gates.max(dim=1)[0]
|
||||||
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
|
sorted_mask1 = mask1[importance_scores.argsort(dim=0)]
|
||||||
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
|
sorted_cumsum1 = fused_cumsum_sub_one(sorted_mask1) * sorted_mask1
|
||||||
importance_sorted_locations1 = sorted_cumsum1[importance_scores.argsort(dim=0).argsort(dim=0)]
|
importance_sorted_locations1 = sorted_cumsum1[
|
||||||
|
importance_scores.argsort(dim=0).argsort(dim=0)
|
||||||
|
]
|
||||||
|
|
||||||
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
|
sorted_mask2 = mask2[importance_scores.argsort(dim=0)]
|
||||||
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
|
sorted_cumsum2 = fused_cumsum_sub_one(sorted_mask2) * sorted_mask2
|
||||||
importance_sorted_locations2 = sorted_cumsum2[importance_scores.argsort(dim=0).argsort(dim=0)]
|
importance_sorted_locations2 = sorted_cumsum2[
|
||||||
|
importance_scores.argsort(dim=0).argsort(dim=0)
|
||||||
|
]
|
||||||
|
|
||||||
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
importance_sorted_locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||||
|
|
||||||
locations1, locations2 = importance_sorted_locations1, importance_sorted_locations2
|
locations1, locations2 = (
|
||||||
|
importance_sorted_locations1,
|
||||||
|
importance_sorted_locations2,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
locations1 = fused_cumsum_sub_one(mask1)
|
locations1 = fused_cumsum_sub_one(mask1)
|
||||||
locations2 = fused_cumsum_sub_one(mask2)
|
locations2 = fused_cumsum_sub_one(mask2)
|
||||||
|
@ -318,8 +349,12 @@ def top2gating(
|
||||||
l_aux = l_aux * num_experts * num_experts
|
l_aux = l_aux * num_experts * num_experts
|
||||||
|
|
||||||
# for logging purposes
|
# for logging purposes
|
||||||
metadata["overflow_expert1"] = 100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
|
metadata["overflow_expert1"] = (
|
||||||
metadata["overflow_expert2"] = 100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
|
100 * torch.sum(mask1 * torch.ge(locations1, capacity)) / torch.sum(mask1)
|
||||||
|
)
|
||||||
|
metadata["overflow_expert2"] = (
|
||||||
|
100 * torch.sum(mask2 * torch.ge(locations2, capacity)) / torch.sum(mask2)
|
||||||
|
)
|
||||||
|
|
||||||
# Remove locations outside capacity from mask
|
# Remove locations outside capacity from mask
|
||||||
mask1_, mask2_ = mask1, mask2
|
mask1_, mask2_ = mask1, mask2
|
||||||
|
@ -327,13 +362,31 @@ def top2gating(
|
||||||
mask2 = mask2 * torch.lt(locations2, capacity)
|
mask2 = mask2 * torch.lt(locations2, capacity)
|
||||||
|
|
||||||
# for logging (percent of tokens routed to each expert)
|
# for logging (percent of tokens routed to each expert)
|
||||||
expert1_hist = 100 * torch.histc((indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
expert1_hist = (
|
||||||
|
100
|
||||||
|
* torch.histc(
|
||||||
|
(indices1_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||||
|
)
|
||||||
|
/ num_tokens
|
||||||
|
)
|
||||||
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
metadata["unused_expert1_count"] = (expert1_hist == 0).sum()
|
||||||
expert1_hist = torch.sort(expert1_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
expert1_hist = (
|
||||||
|
torch.sort(expert1_hist, dim=0, descending=True).values
|
||||||
|
+ torch.finfo(torch.float32).tiny
|
||||||
|
)
|
||||||
|
|
||||||
expert2_hist = 100 * torch.histc((indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts) / num_tokens
|
expert2_hist = (
|
||||||
|
100
|
||||||
|
* torch.histc(
|
||||||
|
(indices2_s.squeeze() + 1), bins=num_experts, min=1, max=num_experts
|
||||||
|
)
|
||||||
|
/ num_tokens
|
||||||
|
)
|
||||||
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
|
metadata["unused_expert2_count"] = (expert2_hist == 0).sum()
|
||||||
expert2_hist = torch.sort(expert2_hist, dim=0, descending=True).values + torch.finfo(torch.float32).tiny
|
expert2_hist = (
|
||||||
|
torch.sort(expert2_hist, dim=0, descending=True).values
|
||||||
|
+ torch.finfo(torch.float32).tiny
|
||||||
|
)
|
||||||
|
|
||||||
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
sample_count = max(math.ceil(num_experts * SAMPLE_FRACTION), 1)
|
||||||
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
metadata["expert1_balance_top"] = expert1_hist[:sample_count].sum()
|
||||||
|
@ -355,8 +408,15 @@ def top2gating(
|
||||||
if has_tutel:
|
if has_tutel:
|
||||||
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
locations1_s = torch.sum(locations1 * mask1_, dim=1)
|
||||||
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
locations2_s = torch.sum(locations2 * mask2_, dim=1)
|
||||||
return l_aux, metadata, capacity, num_experts, \
|
return (
|
||||||
[indices1_s, indices2_s], [locations1_s, locations2_s], [gates1_s, gates2_s]
|
l_aux,
|
||||||
|
metadata,
|
||||||
|
capacity,
|
||||||
|
num_experts,
|
||||||
|
[indices1_s, indices2_s],
|
||||||
|
[locations1_s, locations2_s],
|
||||||
|
[gates1_s, gates2_s],
|
||||||
|
)
|
||||||
|
|
||||||
# Store the capacity location for each token
|
# Store the capacity location for each token
|
||||||
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
locations1_s = torch.sum(locations1 * mask1, dim=1)
|
||||||
|
@ -369,11 +429,13 @@ def top2gating(
|
||||||
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
|
locations2_sc = one_hot(locations2_s, num_classes=capacity, unsqueeze_indices=True)
|
||||||
combine1_sec = torch.bmm(
|
combine1_sec = torch.bmm(
|
||||||
# einsum("se,sc->sec")
|
# einsum("se,sc->sec")
|
||||||
gates1.unsqueeze(-1), locations1_sc.to(gates1.dtype).unsqueeze(1)
|
gates1.unsqueeze(-1),
|
||||||
|
locations1_sc.to(gates1.dtype).unsqueeze(1),
|
||||||
)
|
)
|
||||||
combine2_sec = torch.bmm(
|
combine2_sec = torch.bmm(
|
||||||
# einsum("se,sc->sec")
|
# einsum("se,sc->sec")
|
||||||
gates2.unsqueeze(-1), locations2_sc.to(gates2.dtype).unsqueeze(1)
|
gates2.unsqueeze(-1),
|
||||||
|
locations2_sc.to(gates2.dtype).unsqueeze(1),
|
||||||
)
|
)
|
||||||
combine_weights = combine1_sec + combine2_sec
|
combine_weights = combine1_sec + combine2_sec
|
||||||
dispatch_mask = combine_weights.bool()
|
dispatch_mask = combine_weights.bool()
|
||||||
|
@ -406,7 +468,7 @@ class Top2Gate(torch.nn.Module):
|
||||||
model_dim: int,
|
model_dim: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
use_fp32=False,
|
use_fp32=False,
|
||||||
second_expert_policy='sampling',
|
second_expert_policy="sampling",
|
||||||
normalize_gate_prob_before_dropping=False,
|
normalize_gate_prob_before_dropping=False,
|
||||||
moe_eval_capacity_token_fraction=0.25,
|
moe_eval_capacity_token_fraction=0.25,
|
||||||
batch_prioritized_routing=False,
|
batch_prioritized_routing=False,
|
||||||
|
|
|
@ -3,37 +3,35 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torchscale.architecture.encoder import Encoder
|
from torchscale.architecture.encoder import Encoder
|
||||||
from torchscale.component.embedding import VisionEmbedding, TextEmbedding, PositionalEmbedding
|
from torchscale.component.embedding import (
|
||||||
|
PositionalEmbedding,
|
||||||
|
TextEmbedding,
|
||||||
|
VisionEmbedding,
|
||||||
|
)
|
||||||
from torchscale.component.multiway_network import MultiwayWrapper
|
from torchscale.component.multiway_network import MultiwayWrapper
|
||||||
|
|
||||||
|
|
||||||
class BEiT3(nn.Module):
|
class BEiT3(nn.Module):
|
||||||
|
|
||||||
def __init__(self, args, **kwargs):
|
def __init__(self, args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.args = args
|
self.args = args
|
||||||
assert args.multiway
|
assert args.multiway
|
||||||
assert args.vocab_size > 0
|
assert args.vocab_size > 0
|
||||||
assert not args.share_encoder_input_output_embed
|
assert not args.share_encoder_input_output_embed
|
||||||
self.text_embed = TextEmbedding(
|
self.text_embed = TextEmbedding(args.vocab_size, args.encoder_embed_dim)
|
||||||
args.vocab_size,
|
|
||||||
args.encoder_embed_dim
|
|
||||||
)
|
|
||||||
self.vision_embed = VisionEmbedding(
|
self.vision_embed = VisionEmbedding(
|
||||||
args.img_size,
|
args.img_size,
|
||||||
args.patch_size,
|
args.patch_size,
|
||||||
args.in_chans,
|
args.in_chans,
|
||||||
args.encoder_embed_dim,
|
args.encoder_embed_dim,
|
||||||
contain_mask_token=True,
|
contain_mask_token=True,
|
||||||
prepend_cls_token=True
|
prepend_cls_token=True,
|
||||||
)
|
)
|
||||||
embed_positions = MultiwayWrapper(
|
embed_positions = MultiwayWrapper(
|
||||||
args,
|
args,
|
||||||
PositionalEmbedding(
|
PositionalEmbedding(args.max_source_positions, args.encoder_embed_dim),
|
||||||
args.max_source_positions,
|
|
||||||
args.encoder_embed_dim
|
|
||||||
),
|
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
|
@ -71,7 +69,7 @@ class BEiT3(nn.Module):
|
||||||
encoder_padding_mask = torch.cat(
|
encoder_padding_mask = torch.cat(
|
||||||
[
|
[
|
||||||
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
|
torch.zeros(x1.shape[:-1]).to(x1.device).bool(),
|
||||||
text_padding_position
|
text_padding_position,
|
||||||
],
|
],
|
||||||
dim=1,
|
dim=1,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user