from dataclasses import dataclass
import json

from transformers.configuration_utils import PretrainedConfig


def load_config_from_json(config_file):
    with open(config_file, 'r') as f:
        config = json.load(f)
    config = RetNetConfig.from_dict(config)
    return config
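

# A minimal round-trip sketch (illustrative only): save a config through the
# standard PretrainedConfig API, then reload it with the helper above. The
# "./retnet-tiny" path is a hypothetical example, not part of this module.
#
#     config = RetNetConfig(decoder_layers=2, decoder_embed_dim=128)
#     config.save_pretrained("./retnet-tiny")  # writes ./retnet-tiny/config.json
#     config = load_config_from_json("./retnet-tiny/config.json")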


@dataclass
class RetNetConfig(PretrainedConfig):
    model_type = "retnet"
    initializer_range: float = 0.02
    activation_fn: str = "gelu"
    dropout: float = 0.0  # dropout probability
    activation_dropout: float = 0.0  # dropout probability after activation in FFN
    drop_path_rate: float = 0.0
    decoder_embed_dim: int = 768  # decoder embedding dimension
    decoder_value_embed_dim: int = 1280  # decoder value embedding dimension
    decoder_ffn_embed_dim: int = 1280  # decoder embedding dimension for FFN
    decoder_layers: int = 12  # num decoder layers
    decoder_retention_heads: int = 3  # num decoder retention heads
    decoder_normalize_before: bool = True  # apply layernorm before each decoder block
    layernorm_embedding: bool = False  # add layernorm to embedding
    no_scale_embedding: bool = True  # if True, don't scale embeddings
    recurrent_chunk_size: int = 512
    use_lm_decay: bool = False
    use_glu: bool = True  # use GLU instead of FFN
    z_loss_coeff: float = 0.0  # coefficient for z loss (TODO: 1e-4)
    deepnorm: bool = False
    subln: bool = True
    use_ffn_rms_norm: bool = False
    layernorm_eps: float = 1e-6
    tie_word_embeddings: bool = False

    def __init__(
            self,
            vocab_size: int = 50257,
            initializer_range: float = 0.02,
            is_decoder: bool = True,
            pad_token_id: int = 0,
            eos_token_id: int = 0,
            output_retentions: bool = False,
            use_cache: bool = True,
            forward_impl: str = 'parallel',  # typically 'parallel', 'recurrent', or 'chunkwise'
            activation_fn: str = "gelu",
            dropout: float = 0.0,  # dropout probability
            activation_dropout: float = 0.0,  # dropout probability after activation in FFN
            drop_path_rate: float = 0.0,
            decoder_embed_dim: int = 768,  # decoder embedding dimension
            decoder_value_embed_dim: int = 1280,  # decoder value embedding dimension
            decoder_ffn_embed_dim: int = 1280,  # decoder embedding dimension for FFN
            decoder_layers: int = 12,  # num decoder layers
            decoder_retention_heads: int = 3,  # num decoder retention heads
            decoder_normalize_before: bool = True,  # apply layernorm before each decoder block
            layernorm_embedding: bool = False,  # add layernorm to embedding
            no_scale_embedding: bool = True,  # if True, don't scale embeddings
            recurrent_chunk_size: int = 512,
            use_glu: bool = True,  # use GLU instead of FFN
            z_loss_coeff: float = 0.0,  # coefficient for z loss (TODO: 1e-4)
            use_lm_decay: bool = False,
            deepnorm: bool = False,
            subln: bool = True,
            use_ffn_rms_norm: bool = False,  # use RMSNorm instead of LayerNorm in FFN
            layernorm_eps: float = 1e-6,
            tie_word_embeddings: bool = False,
            **kwargs):
        self.vocab_size = vocab_size
        self.initializer_range = initializer_range
        self.output_retentions = output_retentions
        self.use_lm_decay = use_lm_decay
        self.use_glu = use_glu
        self.z_loss_coeff = z_loss_coeff
        # size related
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_value_embed_dim = decoder_value_embed_dim
        self.decoder_retention_heads = decoder_retention_heads
        self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
        self.decoder_layers = decoder_layers
        # normalization related
        self.decoder_normalize_before = decoder_normalize_before
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.activation_dropout = activation_dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.deepnorm = deepnorm
        self.subln = subln
        self.use_ffn_rms_norm = use_ffn_rms_norm
        self.layernorm_eps = layernorm_eps
        # Blockwise
        self.recurrent_chunk_size = recurrent_chunk_size
        self.forward_impl = forward_impl

        # deepnorm and sub-LN are mutually exclusive normalization schemes:
        # deepnorm selects post-norm (normalize-before off), while sub-LN
        # selects pre-norm; enabling one disables the other.
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False
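        # For example (a sketch of the resulting flags, not additional API):
        #     RetNetConfig(deepnorm=True) -> normalize_before=False, subln=False (post-norm)
        #     RetNetConfig(subln=True)    -> normalize_before=True, deepnorm=False (pre-norm)
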
        super().__init__(is_decoder=is_decoder,
                         pad_token_id=pad_token_id,
                         eos_token_id=eos_token_id,
                         use_cache=use_cache,
                         tie_word_embeddings=tie_word_embeddings,
                         **kwargs)
    def override(self, args):
        # Copy over any attribute of `args` that matches a config field and
        # is not None, leaving unspecified hyperparameters untouched.
        for hp in self.__dict__.keys():
            if getattr(args, hp, None) is not None:
                self.__dict__[hp] = getattr(args, hp, None)
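

# Hedged usage sketch for `override` (illustrative; assumes an argparse-style
# namespace whose attribute names match the config fields):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--decoder_layers", type=int, default=None)
#     args = parser.parse_args(["--decoder_layers", "24"])
#     config = RetNetConfig()
#     config.override(args)  # config.decoder_layers is now 24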