torchscale/examples/fairseq/models/retnet.py
2023-07-24 14:30:13 +08:00

378 lines
15 KiB
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from typing import Optional
import torch
from fairseq import distributed_utils, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
FairseqIncrementalDecoder,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
from omegaconf import II
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
@dataclass
class LanguageConfig(FairseqDataclass):
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="relu", metadata={"help": "activation function to use"}
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
relu_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN."}
)
decoder_embed_dim: int = field(
default=512, metadata={"help": "decoder embedding dimension"}
)
decoder_output_dim: int = field(
default=512, metadata={"help": "decoder output dimension"}
)
decoder_input_dim: int = field(
default=512, metadata={"help": "decoder input dimension"}
)
decoder_ffn_embed_dim: int = field(
default=2048, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
decoder_retention_heads: int = field(
default=2, metadata={"help": "num decoder retention heads"}
)
decoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each decoder block"}
)
share_decoder_input_output_embed: bool = field(
default=False, metadata={"help": "share decoder input and output embeddings"}
)
decoder_learned_pos: bool = field(
default=False,
metadata={"help": "use learned positional embeddings in the decoder"},
)
layernorm_embedding: bool = field(
default=False, metadata={"help": "add layernorm to embedding"}
)
no_scale_embedding: bool = field(
default=False, metadata={"help": "if True, dont scale embeddings"}
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "checkpoint activations at each layer"}
)
offload_activations: bool = field(
default=False,
metadata={"help": "move checkpointed activations to CPU after they are used."},
)
# config for Fully Sharded Data Parallel (FSDP) training
min_params_to_wrap: int = field(
default=DEFAULT_MIN_PARAMS_TO_WRAP,
metadata={
"help": (
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
},
)
moe_freq: int = field(
default=0,
metadata={"help": "Frequency at which we insert MoE Transformer layers"},
)
moe_expert_count: int = field(
default=0, metadata={"help": "Number of experts in each MoE Layer"}
)
moe_gating_use_fp32: bool = field(
default=False,
metadata={"help": "Use FP32 computations in MoE top2 gating function"},
)
moe_second_expert_policy: str = field(
default="sampling",
metadata={"help": "policy for second expert, options: all/sampling/random"},
)
moe_normalize_gate_prob_before_dropping: bool = field(
default=False,
metadata={
"help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
},
)
moe_expert_ffn_dim: Optional[int] = field(
default=None, metadata={"help": "MoE expert FFN dimension"}
)
moe_top1_expert: Optional[bool] = field(
default=False, metadata={"help": "Use top1 gate instead of top2"}
)
moe_eval_capacity_token_fraction: Optional[float] = field(
default=0.25,
metadata={
"help": (
"Default: 0.25, Fraction of tokens as capacity during validation, "
"if set to negative, use same as training. range: (0.0, 1.0]."
)
},
)
moe_normalize_expert_grad: Optional[str] = field(
default="world_size",
metadata={
"help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
},
)
record_a2a_perf_stats: Optional[bool] = field(
default=False,
metadata={"help": "records all to all perf stats during distributed training"},
)
dummy_a2a: Optional[bool] = field(
default=False,
metadata={
"help": "By passes all to all during distributed training by returning the input buffer as output"
},
)
moe_batch_prioritized_routing: Optional[bool] = field(
default=False,
metadata={
"help": "if true orders token by the gate prob before capacity dropping."
},
)
use_xmoe: Optional[bool] = field(
default=False,
)
chunkwise_recurrent: Optional[bool] = field(
default=False,
)
recurrent_chunk_size: Optional[int] = field(
default=512,
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("common.tpu")
memory_efficient_fp16: bool = II("common.memory_efficient_fp16")
fp16: bool = II("common.fp16")
fp16_no_flatten_grads: bool = II("common.fp16_no_flatten_grads")
ddp_backend: str = II("distributed_training.ddp_backend")
world_size: int = II("distributed_training.distributed_world_size")
distributed_rank: int = II("distributed_training.distributed_rank")
ddp_rank: int = II("distributed_training.distributed_rank")
deepnorm: Optional[bool] = field(
default=False,
)
subln: Optional[bool] = field(
default=False,
)
@register_model("retnet", dataclass=LanguageConfig)
class RetNetLanguageModel(FairseqLanguageModel):
def __init__(self, args, decoder):
self.args = args
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_embed_dim
)
if args.share_decoder_input_output_embed:
output_projection = torch.nn.Linear(
embed_tokens.weight.shape[1],
embed_tokens.weight.shape[0],
bias=False,
)
output_projection.weight = embed_tokens.weight
else:
output_projection = torch.nn.Linear(
args.decoder_embed_dim, len(task.dictionary), bias=False
)
torch.nn.init.normal_(
output_projection.weight, mean=0, std=args.decoder_embed_dim**-0.5
)
if getattr(args, "moe_freq", 0) > 0 and (
getattr(args, "fp16", False)
and not getattr(args, "memory_efficient_fp16", False)
and getattr(args, "ddp_backend", None) != "fully_sharded"
):
assert (
args.fp16_no_flatten_grads
), "If training moe models, set --fp16-no-flatten-grads to calculate correct gradnorm"
args.ddp_rank = distributed_utils.get_data_parallel_rank()
config = RetNetConfig()
config.override(args)
decoder = LMDecoder(
config,
embed_tokens,
output_projection,
dictionary=task.dictionary,
)
return cls(args, decoder)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return Embedding(len(dictionary), embed_dim, dictionary.pad())
class LMDecoder(RetNetDecoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
return super().forward(src_tokens, **kwargs)
def max_positions(self):
return self.args.max_target_positions
def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
@register_model_architecture("retnet", "retnet_base")
def retnet_base_architecture(args):
# backward compatibility for older model checkpoints
if hasattr(args, "no_tie_adaptive_proj"):
# previous models defined --no-tie-adaptive-proj, so use the existence of
# that option to determine if this is an "old" model checkpoint
args.no_decoder_final_norm = True # old models always set this to True
if args.no_tie_adaptive_proj is False:
args.tie_adaptive_proj = True
if hasattr(args, "decoder_final_norm"):
args.no_decoder_final_norm = not args.decoder_final_norm
args.dropout = getattr(args, "dropout", 0.0)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 2)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.base_layers = getattr(args, "base_layers", 0)
args.base_sublayers = getattr(args, "base_sublayers", 1)
args.base_shuffle = getattr(args, "base_shuffle", False)
args.add_bos_token = getattr(args, "add_bos_token", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.character_embeddings = getattr(args, "character_embeddings", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.chunkwise_recurrent = getattr(args, "chunkwise_recurrent", False)
args.recurrent_chunk_size = getattr(args, "recurrent_chunk_size", 512)
# Model training is not stable without this
args.decoder_normalize_before = True
args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
args.offload_activations = getattr(args, "offload_activations", False)
if args.offload_activations:
args.checkpoint_activations = True
@register_model_architecture("retnet", "retnet_medium")
def retnet_medium(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 16)
args.decoder_retention_heads = getattr(args, "decoder_retention_heads", 4)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_xl")
def retnet_xl(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_layers = getattr(args, "decoder_layers", 24)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_3b")
def retnet_3b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2560)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 10)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_7b")
def retnet_7b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 4096)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_layers = getattr(args, "decoder_layers", 32)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_13b")
def retnet_13b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 5120)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 10240)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20)
args.decoder_layers = getattr(args, "decoder_layers", 40)
retnet_base_architecture(args)
@register_model_architecture("retnet", "retnet_65b")
def retnet_65b(args):
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 8192)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 16384)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
args.decoder_layers = getattr(args, "decoder_layers", 64)
retnet_base_architecture(args)