This commit is contained in:
Li Dong 2023-07-24 14:30:13 +08:00
parent 89d8c6de9e
commit bf65397b26
4 changed files with 1050 additions and 0 deletions

View File

@ -0,0 +1,377 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
relu_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
decoder_input_dim: int = field(
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_retention_heads: int = field(
default=2, metadata={"help": "num decoder retention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each decoder block"}
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add layernorm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
},
)
moe_freq: int = field(
default=0,
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
},
)
moe_normalize_expert_grad: Optional[str] = field(
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
)
chunkwise_recurrent: Optional[bool] = field(
default=False,
)
recurrent_chunk_size: Optional[int] = field(
default=512,
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("common.tpu")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
fp16: bool = II("common.fp16")
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
ddp_backend: str = II("distributed_training.ddp_backend")
world_size: int = II("distributed_training.distributed_world_size")
distributed_rank: int = II("distributed_training.distributed_rank")
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = RetNetConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
output_projection,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return Embedding(len(dictionary), embed_dim, dictionary.pad())
class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
return super().forward(src_tokens, **kwargs)
def max_positions(self):
return self.args.max_target_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.dropout = getattr(args, "dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.base_layers = getattr(args, "base_layers", 0)
args.base_sublayers = getattr(args, "base_sublayers", 1)
args.base_shuffle = getattr(args, "base_shuffle", False)
args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)

View File

@ -207,3 +207,68 @@ class EncoderDecoderConfig(object):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)
class RetNetConfig(object):
def __init__(self, **kwargs):
self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", 768)
self.decoder_retention_heads = kwargs.pop("decoder_retention_heads", 3)
self.decoder_ffn_embed_dim = kwargs.pop("decoder_ffn_embed_dim", 1536)
self.decoder_layers = kwargs.pop("decoder_layers", 12)
self.decoder_normalize_before = kwargs.pop("decoder_normalize_before", True)
self.activation_fn = kwargs.pop("activation_fn", "gelu")
self.dropout = kwargs.pop("dropout", 0.0)
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
self.moe_freq = kwargs.pop("moe_freq", 0)
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
self.moe_gating_use_fp32 = kwargs.pop("moe_gating_use_fp32", True)
self.moe_eval_capacity_token_fraction = kwargs.pop(
"moe_eval_capacity_token_fraction", 0.25
)
self.moe_second_expert_policy = kwargs.pop("moe_second_expert_policy", "random")
self.moe_normalize_gate_prob_before_dropping = kwargs.pop(
"moe_normalize_gate_prob_before_dropping", False
)
self.use_xmoe = kwargs.pop("use_xmoe", False)
self.rel_pos_buckets = kwargs.pop("rel_pos_buckets", 0)
self.max_rel_pos = kwargs.pop("max_rel_pos", 0)
self.deepnorm = kwargs.pop("deepnorm", False)
self.subln = kwargs.pop("subln", True)
self.multiway = kwargs.pop("multiway", False)
self.share_decoder_input_output_embed = kwargs.pop(
"share_decoder_input_output_embed", False
)
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
self.no_output_layer = kwargs.pop("no_output_layer", False)
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-5)
# Blockwise
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
# Text
self.vocab_size = kwargs.pop("vocab_size", -1)
# Fairscale
self.checkpoint_activations = kwargs.pop("checkpoint_activations", False)
self.fsdp = kwargs.pop("fsdp", False)
self.ddp_rank = kwargs.pop("ddp_rank", 0)
self.xpos_rel_pos = kwargs.pop("xpos_rel_pos", False)
self.xpos_scale_base = kwargs.pop("xpos_scale_base", 512)
if self.deepnorm:
self.decoder_normalize_before = False
self.subln = False
if self.subln:
self.decoder_normalize_before = True
self.deepnorm = False
if self.use_xmoe:
self.moe_normalize_gate_prob_before_dropping = True
self.moe_second_expert_policy = "random"
assert self.moe_freq > 0 and self.moe_expert_count > 0
def override(self, args):
for hp in self.__dict__.keys():
if getattr(args, hp, None) is not None:
self.__dict__[hp] = getattr(args, hp, None)

View File

@ -0,0 +1,403 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import checkpoint_wrapper, wrap
from torchscale.architecture.utils import init_bert_params
from torchscale.component.droppath import DropPath
from torchscale.component.feedforward_network import FeedForwardNetwork, make_experts
from torchscale.component.multiscale_retention import MultiScaleRetention
from torchscale.component.xmoe.moe_layer import MOELayer
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
class RetNetRelPos(nn.Module):
def __init__(self, args):
super().__init__()
angle = 1.0 / (10000 ** torch.linspace(0, 1, args.decoder_embed_dim // args.decoder_retention_heads // 2))
angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
decay = torch.log(1 - 2 ** (-5 - torch.arange(args.decoder_retention_heads, dtype=torch.float)))
self.register_buffer("angle", angle)
self.register_buffer("decay", decay)
self.recurrent_chunk_size = args.recurrent_chunk_size
def forward(self, slen, activate_recurrent=False, chunkwise_recurrent=False):
if activate_recurrent:
sin = torch.sin(self.angle * (slen - 1))
cos = torch.cos(self.angle * (slen - 1))
retention_rel_pos = ((sin, cos), self.decay.exp())
elif chunkwise_recurrent:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
block_index = torch.arange(self.recurrent_chunk_size).to(self.decay)
mask = torch.tril(torch.ones(self.recurrent_chunk_size, self.recurrent_chunk_size).to(self.decay))
mask = torch.masked_fill(block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
scale = mask.sum(dim=-1, keepdim=True).sqrt()
mask = mask / scale
cross_decay = torch.exp(self.decay * self.recurrent_chunk_size)
inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
cross_decay = cross_decay[:, None, None]
inner_decay = inner_decay[:, :, None] / (scale / scale[:, -1, None])
retention_rel_pos = ((sin, cos), (mask, cross_decay, inner_decay))
else:
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
mask = torch.masked_fill(index[:, None] - index[None, :], ~mask.bool(), float("inf"))
mask = torch.exp(mask * self.decay[:, None, None])
mask = torch.nan_to_num(mask)
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
retention_rel_pos = ((sin, cos), mask)
return retention_rel_pos
class DecoderLayer(nn.Module):
def __init__(
self,
args,
depth,
is_moe_layer=False,
):
super().__init__()
self.args = args
self.embed_dim = args.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(args.dropout)
if args.drop_path_rate > 0:
drop_path_prob = np.linspace(0, args.drop_path_rate, args.decoder_layers)[
depth
]
self.drop_path = DropPath(drop_path_prob)
else:
self.drop_path = None
self.retention = self.build_retention(self.embed_dim, args)
self.normalize_before = args.decoder_normalize_before
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
self.is_moe_layer = is_moe_layer
self.ffn_dim = args.decoder_ffn_embed_dim
if not self.is_moe_layer:
self.ffn = self.build_ffn(
self.embed_dim,
self.args,
)
else:
if args.moe_top1_expert:
gate = Top1Gate(
self.embed_dim,
args.moe_expert_count,
use_fp32=args.moe_gating_use_fp32,
moe_eval_capacity_token_fraction=args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
else:
gate = Top2Gate(
self.embed_dim,
args.moe_expert_count,
args.moe_gating_use_fp32,
args.moe_second_expert_policy,
args.moe_normalize_gate_prob_before_dropping,
args.moe_eval_capacity_token_fraction,
use_xmoe=args.use_xmoe,
)
experts = make_experts(args, self.embed_dim, self.ffn_dim)
self.moe_layer = MOELayer(gate, experts, args)
self.final_layer_norm = LayerNorm(self.embed_dim, eps=args.layernorm_eps)
if args.deepnorm:
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
else:
self.alpha = 1.0
def build_ffn(self, embed_dim, args):
return FeedForwardNetwork(
embed_dim,
self.ffn_dim,
args.activation_fn,
args.dropout,
args.activation_dropout,
args.layernorm_eps,
args.subln,
)
def build_retention(self, embed_dim, args):
return MultiScaleRetention(
args,
embed_dim,
args.decoder_retention_heads,
)
def residual_connection(self, x, residual):
return residual * self.alpha + x
def forward(
self,
x,
incremental_state=None,
chunkwise_recurrent=False,
retention_rel_pos=None,
):
residual = x
if self.normalize_before:
x = self.retention_layer_norm(x)
x = self.retention(
x,
incremental_state=incremental_state,
rel_pos=retention_rel_pos,
chunkwise_recurrent=chunkwise_recurrent,
)
x = self.dropout_module(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.retention_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
if not self.is_moe_layer:
x = self.ffn(x)
l_aux = None
else:
x, l_aux = self.moe_layer(x)
if self.drop_path is not None:
x = self.drop_path(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
return x, l_aux
class RetNetDecoder(nn.Module):
def __init__(
self,
args,
embed_tokens=None,
output_projection=None,
**kwargs
):
super().__init__(**kwargs)
self.args = args
self.dropout_module = torch.nn.Dropout(args.dropout)
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_tokens = embed_tokens
if (
output_projection is None
and not args.no_output_layer
and args.vocab_size > 0
):
self.output_projection = self.build_output_projection(args)
else:
self.output_projection = output_projection
if args.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, eps=args.layernorm_eps)
else:
self.layernorm_embedding = None
self.layers = nn.ModuleList([])
moe_freq = args.moe_freq
for i in range(args.decoder_layers):
is_moe_layer = moe_freq != 0 and (i + 1) % moe_freq == 0
self.layers.append(
self.build_decoder_layer(
args,
depth=i,
is_moe_layer=is_moe_layer,
)
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(embed_dim, eps=args.layernorm_eps)
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(args)
self.chunkwise_recurrent = args.chunkwise_recurrent
self.recurrent_chunk_size = args.recurrent_chunk_size
if args.deepnorm:
init_scale = math.pow(8.0 * args.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if args.subln:
init_scale = math.sqrt(math.log(args.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
def build_output_projection(
self,
args,
):
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = self.embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, args.vocab_size, bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
return output_projection
def build_decoder_layer(
self, args, depth, is_moe_layer=False
):
layer = DecoderLayer(
args,
depth,
is_moe_layer=is_moe_layer,
)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
if args.fsdp:
layer = wrap(layer)
return layer
def forward_embedding(
self,
tokens,
token_embedding=None,
incremental_state=None,
):
if incremental_state is not None and not self.is_first_step(incremental_state):
tokens = tokens[:, -1:]
if token_embedding is None:
token_embedding = self.embed_tokens(tokens)
x = embed = self.embed_scale * token_embedding
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x, embed
def is_first_step(self, incremental_state):
if incremental_state is None:
return False
return incremental_state.get("is_first_step", False)
def forward(
self,
prev_output_tokens,
incremental_state=None,
features_only=False,
return_all_hiddens=False,
token_embeddings=None,
**kwargs
):
# embed tokens
x, _ = self.forward_embedding(
prev_output_tokens, token_embeddings, incremental_state
)
is_first_step = self.is_first_step(incremental_state)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
padding_len = self.recurrent_chunk_size - prev_output_tokens.size(1) % self.recurrent_chunk_size
slen = prev_output_tokens.size(1) + padding_len
x = F.pad(x, (0, 0, 0, padding_len))
else:
slen = prev_output_tokens.size(1)
# relative position
retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
# decoder layers
inner_states = [x]
l_aux = []
for idx, layer in enumerate(self.layers):
if incremental_state is None or is_first_step:
if is_first_step and incremental_state is not None:
if idx not in incremental_state:
incremental_state[idx] = {}
else:
if idx not in incremental_state:
incremental_state[idx] = {}
x, l_aux_i = layer(
x,
incremental_state[idx] if incremental_state is not None else None,
retention_rel_pos=retention_rel_pos,
chunkwise_recurrent=self.chunkwise_recurrent,
)
l_aux.append(l_aux_i)
inner_states.append(x)
if self.chunkwise_recurrent and prev_output_tokens.size(1) % self.recurrent_chunk_size != 0:
x = x[:, :prev_output_tokens.size(1), :]
if self.layer_norm is not None:
x = self.layer_norm(x)
if not features_only:
x = self.output_layer(x)
return x, {
"inner_states": inner_states,
"l_aux": l_aux,
"attn": None,
}
def output_layer(self, features):
return self.output_projection(features)

View File

@ -0,0 +1,205 @@
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import torch.nn.functional as F
from torch import nn
try:
from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
from torch.nn import LayerNorm
from .multiway_network import MultiwayWrapper
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m
def theta_shift(x, sin, cos):
return (x * cos) + (rotate_every_two(x) * sin)
def get_activation_fn(activation):
if activation == "swish":
return F.silu
elif activation == "gelu":
return F.gelu
else:
raise NotImplementedError
class MultiScaleRetention(nn.Module):
def __init__(
self,
args,
embed_dim,
num_heads,
value_factor=2,
gate_fn="swish",
):
super().__init__()
self.args = args
self.factor = value_factor
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim * self.factor // num_heads
self.key_dim = self.embed_dim // num_heads
self.scaling = self.key_dim ** -0.5
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim * self.factor, bias=True))
self.out_proj = MultiwayWrapper(args, nn.Linear(embed_dim * self.factor, embed_dim, bias=True))
self.group_norm = MultiwayWrapper(args, LayerNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
nn.init.xavier_uniform_(self.out_proj.weight)
nn.init.constant_(self.out_proj.bias, 0.0)
def parallel_forward(self, qr, kr, v, mask):
bsz, tgt_len, embed_dim = v.size()
vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
qk_mat = qk_mat * mask
# invariant after normalization
qk_mat = qk_mat / qk_mat.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
output = torch.matmul(qk_mat, vr)
output = output.transpose(1, 2)
return output
def recurrent_forward(
self,
qr, kr, v,
decay,
incremental_state
):
bsz = v.size(0)
v = v.view(bsz, self.num_heads, self.head_dim, 1)
kv = kr * v
if "prev_key_value" in incremental_state:
prev_kv = incremental_state["prev_key_value"]
prev_scale = incremental_state["scale"]
scale = prev_scale * decay + 1
kv = prev_kv * (1 - 1 / scale).view(self.num_heads, 1, 1) + kv / scale.view(self.num_heads, 1, 1)
# kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
else:
scale = torch.ones_like(decay)
incremental_state["prev_key_value"] = kv
incremental_state["scale"] = scale
output = torch.sum(qr * kv, dim=3)
return output
def chunk_recurrent_forward(
self,
qr, kr, v,
inner_mask
):
mask, cross_decay, inner_decay = inner_mask
bsz, tgt_len, embed_dim = v.size()
chunk_len = mask.size(1)
num_chunks = tgt_len // chunk_len
assert tgt_len % chunk_len == 0
qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)
kr_t = kr.transpose(-1, -2)
qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
qk_mat = qk_mat * mask
inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
qk_mat = qk_mat / inner_scale
inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
# reduce kv in one chunk
kv = kr_t @ (v * mask[:, -1, :, None])
kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
kv_recurrent = []
cross_scale = []
kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
kv_scale = torch.ones(bsz, self.num_heads, 1, self.head_dim).to(v)
# accumulate kv by loop
for i in range(num_chunks):
kv_recurrent.append(kv_state / kv_scale)
cross_scale.append(kv_scale)
kv_state = kv_state * cross_decay + kv[:, i]
kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).clamp(min=1)
kv_recurrent = torch.stack(kv_recurrent, dim=1)
cross_scale = torch.stack(cross_scale, dim=1)
cross_output = (qr * inner_decay) @ kv_recurrent
output = inner_output / cross_scale + cross_output / inner_scale
output = output.transpose(2, 3)
return output
def forward(
self,
x,
rel_pos,
chunkwise_recurrent=False,
incremental_state=None
):
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
g = self.g_proj(x)
k *= self.scaling
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
if incremental_state is not None:
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
output = self.parallel_forward(qr, kr, v, inner_mask)
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output