rollback variant name
This commit is contained in:
parent
7f07609361
commit
fd8234c2ac
|
@ -70,7 +70,7 @@ class LanguageConfig(FairseqDataclass):
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "use learned positional embeddings in the decoder"},
|
metadata={"help": "use learned positional embeddings in the decoder"},
|
||||||
)
|
)
|
||||||
norm_embedding: bool = field(
|
layernorm_embedding: bool = field(
|
||||||
default=False, metadata={"help": "add norm to embedding"}
|
default=False, metadata={"help": "add norm to embedding"}
|
||||||
)
|
)
|
||||||
no_scale_embedding: bool = field(
|
no_scale_embedding: bool = field(
|
||||||
|
@ -325,7 +325,7 @@ def retnet_base_architecture(args):
|
||||||
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
|
args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
|
||||||
|
|
||||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
||||||
args.norm_embedding = getattr(args, "norm_embedding", False)
|
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
||||||
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
||||||
args.offload_activations = getattr(args, "offload_activations", False)
|
args.offload_activations = getattr(args, "offload_activations", False)
|
||||||
if args.offload_activations:
|
if args.offload_activations:
|
||||||
|
|
|
@ -222,7 +222,7 @@ class RetNetConfig(object):
|
||||||
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
|
self.drop_path_rate = kwargs.pop("drop_path_rate", 0.0)
|
||||||
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
|
self.activation_dropout = kwargs.pop("activation_dropout", 0.0)
|
||||||
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
|
self.no_scale_embedding = kwargs.pop("no_scale_embedding", True)
|
||||||
self.norm_embedding = kwargs.pop("norm_embedding", False)
|
self.layernorm_embedding = kwargs.pop("layernorm_embedding", False)
|
||||||
self.moe_freq = kwargs.pop("moe_freq", 0)
|
self.moe_freq = kwargs.pop("moe_freq", 0)
|
||||||
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
self.moe_top1_expert = kwargs.pop("moe_top1_expert", False)
|
||||||
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
self.moe_expert_count = kwargs.pop("moe_expert_count", 0)
|
||||||
|
@ -245,7 +245,7 @@ class RetNetConfig(object):
|
||||||
)
|
)
|
||||||
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
self.max_target_positions = kwargs.pop("max_target_positions", 1024)
|
||||||
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
self.no_output_layer = kwargs.pop("no_output_layer", False)
|
||||||
self.norm_eps = kwargs.pop("norm_eps", 1e-6)
|
self.layernorm_eps = kwargs.pop("layernorm_eps", 1e-6)
|
||||||
# Blockwise
|
# Blockwise
|
||||||
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
|
self.chunkwise_recurrent = kwargs.pop("chunkwise_recurrent", False)
|
||||||
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
|
self.recurrent_chunk_size = kwargs.pop("recurrent_chunk_size", 512)
|
||||||
|
|
|
@ -11,7 +11,8 @@ from fairscale.nn import checkpoint_wrapper, wrap
|
||||||
|
|
||||||
from torchscale.architecture.utils import init_bert_params
|
from torchscale.architecture.utils import init_bert_params
|
||||||
from torchscale.component.droppath import DropPath
|
from torchscale.component.droppath import DropPath
|
||||||
from torchscale.component.gate_linear_unit import GLU, make_experts
|
from torchscale.component.feedforward_network import make_experts
|
||||||
|
from torchscale.component.gate_linear_unit import GLU
|
||||||
from torchscale.component.multiscale_retention import MultiScaleRetention
|
from torchscale.component.multiscale_retention import MultiScaleRetention
|
||||||
from torchscale.component.xmoe.moe_layer import MOELayer
|
from torchscale.component.xmoe.moe_layer import MOELayer
|
||||||
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
from torchscale.component.xmoe.routing import Top1Gate, Top2Gate
|
||||||
|
@ -88,7 +89,7 @@ class DecoderLayer(nn.Module):
|
||||||
|
|
||||||
self.normalize_before = args.decoder_normalize_before
|
self.normalize_before = args.decoder_normalize_before
|
||||||
|
|
||||||
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps)
|
self.retention_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
|
||||||
|
|
||||||
self.is_moe_layer = is_moe_layer
|
self.is_moe_layer = is_moe_layer
|
||||||
self.ffn_dim = args.decoder_ffn_embed_dim
|
self.ffn_dim = args.decoder_ffn_embed_dim
|
||||||
|
@ -120,7 +121,7 @@ class DecoderLayer(nn.Module):
|
||||||
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
experts = make_experts(args, self.embed_dim, self.ffn_dim)
|
||||||
self.moe_layer = MOELayer(gate, experts, args)
|
self.moe_layer = MOELayer(gate, experts, args)
|
||||||
|
|
||||||
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.norm_eps)
|
self.final_layer_norm = RMSNorm(self.embed_dim, eps=args.layernorm_eps)
|
||||||
|
|
||||||
if args.deepnorm:
|
if args.deepnorm:
|
||||||
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
|
self.alpha = math.pow(2.0 * args.decoder_layers, 0.25)
|
||||||
|
@ -220,10 +221,10 @@ class RetNetDecoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.output_projection = output_projection
|
self.output_projection = output_projection
|
||||||
|
|
||||||
if args.norm_embedding:
|
if args.layernorm_embedding:
|
||||||
self.norm_embedding = RMSNorm(embed_dim, eps=args.norm_eps)
|
self.layernorm_embedding = RMSNorm(embed_dim, eps=args.layernorm_eps)
|
||||||
else:
|
else:
|
||||||
self.norm_embedding = None
|
self.layernorm_embedding = None
|
||||||
|
|
||||||
self.layers = nn.ModuleList([])
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
@ -241,7 +242,7 @@ class RetNetDecoder(nn.Module):
|
||||||
self.num_layers = len(self.layers)
|
self.num_layers = len(self.layers)
|
||||||
|
|
||||||
if args.decoder_normalize_before:
|
if args.decoder_normalize_before:
|
||||||
self.layer_norm = RMSNorm(embed_dim, eps=args.norm_eps)
|
self.layer_norm = RMSNorm(embed_dim, eps=args.layernorm_eps)
|
||||||
else:
|
else:
|
||||||
self.layer_norm = None
|
self.layer_norm = None
|
||||||
|
|
||||||
|
@ -309,8 +310,8 @@ class RetNetDecoder(nn.Module):
|
||||||
|
|
||||||
x = embed = self.embed_scale * token_embedding
|
x = embed = self.embed_scale * token_embedding
|
||||||
|
|
||||||
if self.norm_embedding is not None:
|
if self.layernorm_embedding is not None:
|
||||||
x = self.norm_embedding(x)
|
x = self.layernorm_embedding(x)
|
||||||
|
|
||||||
x = self.dropout_module(x)
|
x = self.dropout_module(x)
|
||||||
|
|
||||||
|
@ -345,7 +346,7 @@ class RetNetDecoder(nn.Module):
|
||||||
slen = prev_output_tokens.size(1)
|
slen = prev_output_tokens.size(1)
|
||||||
# relative position
|
# relative position
|
||||||
retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
|
retention_rel_pos = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=self.chunkwise_recurrent)
|
||||||
|
retention_rel_pos_no_block = self.retnet_rel_pos(slen, incremental_state is not None and not is_first_step, chunkwise_recurrent=False)
|
||||||
# decoder layers
|
# decoder layers
|
||||||
inner_states = [x]
|
inner_states = [x]
|
||||||
|
|
||||||
|
@ -360,12 +361,20 @@ class RetNetDecoder(nn.Module):
|
||||||
if idx not in incremental_state:
|
if idx not in incremental_state:
|
||||||
incremental_state[idx] = {}
|
incremental_state[idx] = {}
|
||||||
|
|
||||||
|
x_no_block, _ = layer(
|
||||||
|
x,
|
||||||
|
incremental_state[idx] if incremental_state is not None else None,
|
||||||
|
retention_rel_pos=retention_rel_pos_no_block,
|
||||||
|
chunkwise_recurrent=False,
|
||||||
|
)
|
||||||
x, l_aux_i = layer(
|
x, l_aux_i = layer(
|
||||||
x,
|
x,
|
||||||
incremental_state[idx] if incremental_state is not None else None,
|
incremental_state[idx] if incremental_state is not None else None,
|
||||||
retention_rel_pos=retention_rel_pos,
|
retention_rel_pos=retention_rel_pos,
|
||||||
chunkwise_recurrent=self.chunkwise_recurrent,
|
chunkwise_recurrent=self.chunkwise_recurrent,
|
||||||
)
|
)
|
||||||
|
print(x[0], x_no_block[0], (x - x_no_block).abs().max(), (x - x_no_block).abs().sum())
|
||||||
|
exit()
|
||||||
l_aux.append(l_aux_i)
|
l_aux.append(l_aux_i)
|
||||||
inner_states.append(x)
|
inner_states.append(x)
|
||||||
|
|
||||||
|
|
|
@ -96,6 +96,8 @@ def get_activation_fn(activation):
|
||||||
return F.relu
|
return F.relu
|
||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
return F.gelu
|
return F.gelu
|
||||||
|
elif activation == "swish":
|
||||||
|
return F.silu
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -5,96 +5,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .xmoe.global_groups import get_moe_group
|
from .feedforward_network import get_activation_fn
|
||||||
|
|
||||||
|
|
||||||
class set_torch_seed(object):
|
|
||||||
def __init__(self, seed):
|
|
||||||
assert isinstance(seed, int)
|
|
||||||
self.rng_state = self.get_rng_state()
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
|
|
||||||
def get_rng_state(self):
|
|
||||||
state = {"torch_rng_state": torch.get_rng_state()}
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
state["cuda_rng_state"] = torch.cuda.get_rng_state()
|
|
||||||
return state
|
|
||||||
|
|
||||||
def set_rng_state(self, state):
|
|
||||||
torch.set_rng_state(state["torch_rng_state"])
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.set_rng_state(state["cuda_rng_state"])
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *exc):
|
|
||||||
self.set_rng_state(self.rng_state)
|
|
||||||
|
|
||||||
|
|
||||||
def make_experts(args, embed_dim, expert_ffn_dim):
|
|
||||||
world_size = (
|
|
||||||
1
|
|
||||||
if not torch.distributed.is_initialized()
|
|
||||||
else torch.distributed.get_world_size()
|
|
||||||
)
|
|
||||||
expert_list = []
|
|
||||||
ddp_rank = args.ddp_rank
|
|
||||||
start_seed = torch.randint(1000000, (1,)).item()
|
|
||||||
# at least as many experts than gpus
|
|
||||||
if args.moe_expert_count >= world_size:
|
|
||||||
assert (
|
|
||||||
args.moe_expert_count % world_size == 0
|
|
||||||
), f"{args.moe_expert_count}, {world_size}"
|
|
||||||
local_moe_expert_count = args.moe_expert_count // world_size
|
|
||||||
for i in range(local_moe_expert_count):
|
|
||||||
with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
|
|
||||||
expert_list.append(
|
|
||||||
GLU(
|
|
||||||
embed_dim,
|
|
||||||
expert_ffn_dim,
|
|
||||||
args.activation_fn,
|
|
||||||
args.dropout,
|
|
||||||
args.activation_dropout,
|
|
||||||
args.layernorm_eps,
|
|
||||||
args.subln,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
world_size % args.moe_expert_count == 0
|
|
||||||
), f"{world_size}, {args.moe_expert_count}"
|
|
||||||
|
|
||||||
moe_idx, _ = get_moe_group(args.moe_expert_count)
|
|
||||||
|
|
||||||
with set_torch_seed(start_seed + moe_idx):
|
|
||||||
expert_list.append(
|
|
||||||
GLU(
|
|
||||||
embed_dim,
|
|
||||||
expert_ffn_dim,
|
|
||||||
args.activation_fn,
|
|
||||||
args.dropout,
|
|
||||||
args.activation_dropout,
|
|
||||||
args.layernorm_eps,
|
|
||||||
args.subln,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
experts = nn.ModuleList(expert_list)
|
|
||||||
return experts
|
|
||||||
|
|
||||||
|
|
||||||
def get_activation_fn(activation):
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
elif activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
elif activation == "swish":
|
|
||||||
return F.silu
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class GLU(nn.Module):
|
class GLU(nn.Module):
|
||||||
|
@ -118,6 +29,7 @@ class GLU(nn.Module):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.fc1.reset_parameters()
|
self.fc1.reset_parameters()
|
||||||
self.fc2.reset_parameters()
|
self.fc2.reset_parameters()
|
||||||
|
self.gate.reset_parameters()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_shape = x.shape
|
x_shape = x.shape
|
||||||
|
|
Loading…
Reference in New Issue
Block a user