added FP8 support through NVIDIA/TransformerEngine, added RetNet_HF through syncdoth/RetNet (as an alternative to branch away from torchscale)
This commit is contained in:
parent
7075c2a5f0
commit
9d97eb5104
@ -176,6 +176,11 @@ class Model:
|
|||||||
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
|
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
|
||||||
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
|
||||||
|
|
||||||
|
@property
|
||||||
|
# required for fp8 as the lengths needs to be divisible by 8
|
||||||
|
def input_alignment(self):
|
||||||
|
return 8 if cfg.fp8.enabled else 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def full_name(self):
|
def full_name(self):
|
||||||
name = [ self.name ]
|
name = [ self.name ]
|
||||||
@ -503,6 +508,10 @@ class Trainer:
|
|||||||
return torch.float16
|
return torch.float16
|
||||||
if self.weight_dtype == "bfloat16":
|
if self.weight_dtype == "bfloat16":
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
|
if self.weight_dtype == "float8_e5m2":
|
||||||
|
return torch.float8_e5m2
|
||||||
|
if self.weight_dtype == "float8_e4m3fn":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
|
|
||||||
@ -527,6 +536,10 @@ class Inference:
|
|||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
if self.weight_dtype == "int8":
|
if self.weight_dtype == "int8":
|
||||||
return torch.int8
|
return torch.int8
|
||||||
|
if self.weight_dtype == "float8_e5m2":
|
||||||
|
return torch.float8_e5m2
|
||||||
|
if self.weight_dtype == "float8_e4m3fn":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
@ -540,6 +553,11 @@ class BitsAndBytes:
|
|||||||
|
|
||||||
bitnet: bool = False
|
bitnet: bool = False
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
|
class FP8:
|
||||||
|
enabled: bool = False
|
||||||
|
backend: str = "te"
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class Config(_Config):
|
class Config(_Config):
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
@ -553,6 +571,8 @@ class Config(_Config):
|
|||||||
trainer: Trainer = field(default_factory=lambda: Trainer)
|
trainer: Trainer = field(default_factory=lambda: Trainer)
|
||||||
inference: Inference = field(default_factory=lambda: Inference)
|
inference: Inference = field(default_factory=lambda: Inference)
|
||||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
||||||
|
|
||||||
|
fp8: FP8 = field(default_factory=lambda: FP8)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sample_rate(self):
|
def sample_rate(self):
|
||||||
@ -620,6 +640,7 @@ try:
|
|||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,7 @@ from typing import Any, Protocol
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
from .base import TrainFeeder
|
from .base import TrainFeeder
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -222,10 +223,11 @@ class Engine():
|
|||||||
return self._global_grad_norm
|
return self._global_grad_norm
|
||||||
|
|
||||||
def traverse(self, *args, **kwargs):
|
def traverse(self, *args, **kwargs):
|
||||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with ml.autocast():
|
||||||
self.forward(*args, **kwargs)
|
self.forward(*args, **kwargs)
|
||||||
losses = self.gather_attribute("loss")
|
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
losses = self.gather_attribute("loss")
|
||||||
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distr
|
|||||||
from deepspeed.accelerator import get_accelerator
|
from deepspeed.accelerator import get_accelerator
|
||||||
|
|
||||||
from ..utils.distributed import init_distributed, distributed_initialized
|
from ..utils.distributed import init_distributed, distributed_initialized
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
|
if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
|
||||||
init_distributed(init_deepspeed_dist)
|
init_distributed(init_deepspeed_dist)
|
||||||
@ -106,10 +107,11 @@ class Engine(DeepSpeedEngine):
|
|||||||
print(str(e))
|
print(str(e))
|
||||||
|
|
||||||
def traverse(self, *args, **kwargs):
|
def traverse(self, *args, **kwargs):
|
||||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with ml.autocast():
|
||||||
self.forward(*args, **kwargs)
|
self.forward(*args, **kwargs)
|
||||||
losses = self.gather_attribute("loss")
|
|
||||||
loss = torch.stack([*losses.values()]).sum()
|
losses = self.gather_attribute("loss")
|
||||||
|
loss = torch.stack([*losses.values()]).sum()
|
||||||
|
|
||||||
stats = {}
|
stats = {}
|
||||||
stats |= {k: v.item() for k, v in losses.items()}
|
stats |= {k: v.item() for k, v in losses.items()}
|
||||||
|
|||||||
0
vall_e/ext/__init__.py
Normal file
0
vall_e/ext/__init__.py
Normal file
3
vall_e/ext/retnet_hf/__init__.py
Normal file
3
vall_e/ext/retnet_hf/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# from https://github.com/syncdoth/RetNet/
|
||||||
|
|
||||||
|
# there is no proper build system and I can't be assed to fork it or make it a submodule that plays nicely with python's import system
|
||||||
117
vall_e/ext/retnet_hf/configuration_retnet.py
Normal file
117
vall_e/ext/retnet_hf/configuration_retnet.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
def load_config_from_json(config_file):
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
config = RetNetConfig.from_dict(config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetNetConfig(PretrainedConfig):
|
||||||
|
model_type = "retnet"
|
||||||
|
initializer_range: float = 0.02
|
||||||
|
activation_fn: str = "gelu"
|
||||||
|
dropout: float = 0.0 # dropout probability
|
||||||
|
activation_dropout: float = 0.0 # dropout probability after activation in FFN.
|
||||||
|
drop_path_rate: float = 0.0
|
||||||
|
decoder_embed_dim: int = 768 # decoder embedding dimension
|
||||||
|
decoder_value_embed_dim: int = 1280 # decoder value embedding dimension
|
||||||
|
decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN
|
||||||
|
decoder_layers: int = 12 # num decoder layers
|
||||||
|
decoder_retention_heads: int = 3 # num decoder retention heads
|
||||||
|
decoder_normalize_before: bool = True # apply layernorm before each decoder block
|
||||||
|
layernorm_embedding: bool = False # add layernorm to embedding
|
||||||
|
no_scale_embedding: bool = True # if True, dont scale embeddings
|
||||||
|
recurrent_chunk_size: int = 512
|
||||||
|
use_lm_decay: bool = False
|
||||||
|
use_glu: bool = True # use GLU instead of FFN
|
||||||
|
z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4
|
||||||
|
deepnorm: bool = False
|
||||||
|
subln: bool = True
|
||||||
|
use_ffn_rms_norm: bool = False
|
||||||
|
layernorm_eps: float = 1e-6
|
||||||
|
tie_word_embeddings: bool = False
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int = 50257,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
is_decoder: bool = True,
|
||||||
|
pad_token_id: int = 0,
|
||||||
|
eos_token_id: int = 0,
|
||||||
|
output_retentions: bool = False,
|
||||||
|
use_cache: bool = True,
|
||||||
|
forward_impl: str = 'parallel',
|
||||||
|
activation_fn: str = "gelu",
|
||||||
|
dropout: float = 0.0, # dropout probability
|
||||||
|
activation_dropout: float = 0.0, # dropout probability after activation in FFN.
|
||||||
|
drop_path_rate: float = 0.0,
|
||||||
|
decoder_embed_dim: int = 768, # decoder embedding dimension
|
||||||
|
decoder_value_embed_dim: int = 1280, # decoder value embedding dimension
|
||||||
|
decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN
|
||||||
|
decoder_layers: int = 12, # num decoder layers
|
||||||
|
decoder_retention_heads: int = 3, # num decoder retention heads
|
||||||
|
decoder_normalize_before: bool = True, # apply layernorm before each decoder block
|
||||||
|
layernorm_embedding: bool = False, # add layernorm to embedding
|
||||||
|
no_scale_embedding: bool = True, # if True, dont scale embeddings
|
||||||
|
recurrent_chunk_size: int = 512,
|
||||||
|
use_glu: bool = True, # use GLU instead of FFN
|
||||||
|
z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
|
||||||
|
use_lm_decay: bool = False,
|
||||||
|
deepnorm: bool = True,
|
||||||
|
subln: bool = True,
|
||||||
|
use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN
|
||||||
|
layernorm_eps: float = 1e-6,
|
||||||
|
tie_word_embeddings: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.output_retentions = output_retentions
|
||||||
|
self.use_lm_decay = use_lm_decay
|
||||||
|
self.use_glu = use_glu
|
||||||
|
self.z_loss_coeff = z_loss_coeff
|
||||||
|
# size related
|
||||||
|
self.decoder_embed_dim = decoder_embed_dim
|
||||||
|
self.decoder_value_embed_dim = decoder_value_embed_dim
|
||||||
|
self.decoder_retention_heads = decoder_retention_heads
|
||||||
|
self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
# normalization related
|
||||||
|
self.decoder_normalize_before = decoder_normalize_before
|
||||||
|
self.activation_fn = activation_fn
|
||||||
|
self.dropout = dropout
|
||||||
|
self.drop_path_rate = drop_path_rate
|
||||||
|
self.activation_dropout = activation_dropout
|
||||||
|
self.no_scale_embedding = no_scale_embedding
|
||||||
|
self.layernorm_embedding = layernorm_embedding
|
||||||
|
self.deepnorm = deepnorm
|
||||||
|
self.subln = subln
|
||||||
|
self.use_ffn_rms_norm = use_ffn_rms_norm
|
||||||
|
self.layernorm_eps = layernorm_eps
|
||||||
|
# Blockwise
|
||||||
|
self.recurrent_chunk_size = recurrent_chunk_size
|
||||||
|
self.forward_impl = forward_impl
|
||||||
|
|
||||||
|
if self.deepnorm:
|
||||||
|
self.decoder_normalize_before = False
|
||||||
|
self.subln = False
|
||||||
|
if self.subln:
|
||||||
|
self.decoder_normalize_before = True
|
||||||
|
self.deepnorm = False
|
||||||
|
|
||||||
|
super().__init__(is_decoder=is_decoder,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
use_cache=use_cache,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def override(self, args):
|
||||||
|
for hp in self.__dict__.keys():
|
||||||
|
if getattr(args, hp, None) is not None:
|
||||||
|
self.__dict__[hp] = getattr(args, hp, None)
|
||||||
1455
vall_e/ext/retnet_hf/modeling_retnet.py
Normal file
1455
vall_e/ext/retnet_hf/modeling_retnet.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -20,7 +20,9 @@ def get_model(cfg, training=True):
|
|||||||
n_layers=cfg.layers,
|
n_layers=cfg.layers,
|
||||||
n_experts=cfg.experts,
|
n_experts=cfg.experts,
|
||||||
|
|
||||||
training=training,
|
l_padding = cfg.input_alignment,
|
||||||
|
|
||||||
|
training = training,
|
||||||
config = cfg,
|
config = cfg,
|
||||||
)
|
)
|
||||||
model._cfg = cfg
|
model._cfg = cfg
|
||||||
|
|||||||
@ -300,7 +300,7 @@ class AR_NAR(Base):
|
|||||||
|
|
||||||
|
|
||||||
def example_usage():
|
def example_usage():
|
||||||
cfg.trainer.backend = "local"
|
#cfg.trainer.backend = "local"
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
@ -317,7 +317,7 @@ def example_usage():
|
|||||||
def tokenize(content, lang_marker="en"):
|
def tokenize(content, lang_marker="en"):
|
||||||
split = content.split(" ")
|
split = content.split(" ")
|
||||||
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
|
||||||
return torch.tensor([*map(symmap.get, phones)]).to()
|
return torch.tensor([*map(symmap.get, phones)])
|
||||||
|
|
||||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
|
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
|
||||||
|
|
||||||
@ -344,6 +344,8 @@ def example_usage():
|
|||||||
'n_heads': 16, # 4, # 16, # 24
|
'n_heads': 16, # 4, # 16, # 24
|
||||||
'n_layers': 12, # 32
|
'n_layers': 12, # 32
|
||||||
'n_experts': 1,
|
'n_experts': 1,
|
||||||
|
|
||||||
|
'l_padding': 8,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@ -366,6 +368,7 @@ def example_usage():
|
|||||||
steps = 500
|
steps = 500
|
||||||
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
||||||
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
# copy embeddings if requested
|
# copy embeddings if requested
|
||||||
@ -392,15 +395,15 @@ def example_usage():
|
|||||||
param.requires_grad_(False)
|
param.requires_grad_(False)
|
||||||
engine._frozen_params.add(param)
|
engine._frozen_params.add(param)
|
||||||
|
|
||||||
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
|
# if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
|
||||||
model.model = ml.replace_linear( model.model )
|
model.model = ml.replace_linear( model.model )
|
||||||
|
|
||||||
torch.save( {
|
torch.save( {
|
||||||
'module': model.state_dict()
|
'module': model.state_dict()
|
||||||
}, "./data/test.pth" )
|
}, "./data/test.pth" )
|
||||||
|
|
||||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def sample( name, steps=600 ):
|
def sample( name, steps=600 ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
|
|||||||
@ -29,6 +29,14 @@ except Exception as e:
|
|||||||
print("Error importing `retnet` arch:", e)
|
print("Error importing `retnet` arch:", e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
from .retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
except Exception as e:
|
||||||
|
print("Error importing `retnet-hf` arch:", e)
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers import LlamaModel, LlamaConfig
|
from transformers import LlamaModel, LlamaConfig
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -44,6 +52,7 @@ except Exception as e:
|
|||||||
try:
|
try:
|
||||||
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
from bitnet.bit_transformer import Transformer as BitNetTransformerBlock, RMSNorm as BitNetRMSNorm
|
||||||
|
|
||||||
|
# override because bitnet's BitNetTransformer includes an embedding input / classifier output layers inside of it, which isn't favorable
|
||||||
class BitNetTransformer(nn.Module):
|
class BitNetTransformer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -159,7 +168,6 @@ class Embedding(nn.Embedding):
|
|||||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
if len(x_list) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||||
|
|
||||||
class MultiEmbedding(nn.Module):
|
class MultiEmbedding(nn.Module):
|
||||||
@ -308,7 +316,9 @@ class Base(nn.Module):
|
|||||||
n_layers: int = 12,
|
n_layers: int = 12,
|
||||||
p_dropout: float = 0.1,
|
p_dropout: float = 0.1,
|
||||||
|
|
||||||
n_experts: int=1,
|
n_experts: int = 1,
|
||||||
|
|
||||||
|
l_padding: int = 0,
|
||||||
|
|
||||||
training = True,
|
training = True,
|
||||||
config = None,
|
config = None,
|
||||||
@ -323,6 +333,8 @@ class Base(nn.Module):
|
|||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
self.n_experts = n_experts
|
self.n_experts = n_experts
|
||||||
|
|
||||||
|
self.l_padding = l_padding
|
||||||
|
|
||||||
# +1 to include the stop token
|
# +1 to include the stop token
|
||||||
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
|
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
|
||||||
@ -460,6 +472,27 @@ class Base(nn.Module):
|
|||||||
))
|
))
|
||||||
|
|
||||||
self.model = RetNetDecoder(RetNetConfig(**kwargs))
|
self.model = RetNetDecoder(RetNetConfig(**kwargs))
|
||||||
|
elif self.arch_type == "retnet-hf":
|
||||||
|
kwargs = dict(
|
||||||
|
vocab_size=n_resp_tokens,
|
||||||
|
decoder_embed_dim=d_model,
|
||||||
|
decoder_value_embed_dim =d_model * 2,
|
||||||
|
decoder_retention_heads=n_heads,
|
||||||
|
decoder_ffn_embed_dim=d_model * 4,
|
||||||
|
decoder_layers=n_layers,
|
||||||
|
dropout=p_dropout if training else 0.0,
|
||||||
|
checkpoint_activations=self.activation_checkpointing,
|
||||||
|
activation_fn="gelu",
|
||||||
|
use_glu=False, # self.version >= 3,
|
||||||
|
|
||||||
|
recurrent_chunk_size=self.recurrent_chunk_size if self.causal else 0,
|
||||||
|
decoder_normalize_before=True,
|
||||||
|
|
||||||
|
deepnorm=False,
|
||||||
|
subln=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs))
|
||||||
elif self.arch_type == "bitnet":
|
elif self.arch_type == "bitnet":
|
||||||
self.model = BitNetTransformer(
|
self.model = BitNetTransformer(
|
||||||
num_tokens=n_resp_tokens,
|
num_tokens=n_resp_tokens,
|
||||||
@ -514,19 +547,50 @@ class Base(nn.Module):
|
|||||||
sep=self.sep,
|
sep=self.sep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
x, m = list_to_tensor(x_list)
|
x, m = list_to_tensor(x_list)
|
||||||
aux_loss = None
|
aux_loss = None
|
||||||
|
|
||||||
device = x.device
|
device = x.device
|
||||||
|
|
||||||
|
# pad our input and mask, but retain the original length by doing it after
|
||||||
|
if self.l_padding and x.shape[1] % self.l_padding != 0:
|
||||||
|
# pad input
|
||||||
|
shape = list(x.shape)
|
||||||
|
shape[1] = self.l_padding - shape[1] % self.l_padding
|
||||||
|
|
||||||
|
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
|
||||||
|
x = torch.cat([x, padding], dim=1)
|
||||||
|
|
||||||
|
# pad mask
|
||||||
|
shape[2] = 1
|
||||||
|
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
|
||||||
|
m = torch.cat([m, padding], dim=1)
|
||||||
|
|
||||||
if state is not None and self.arch_type == "retnet":
|
if state is not None and self.arch_type == "retnet":
|
||||||
# prefill
|
# prefill
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
prefill_size = x.shape[1]
|
prefill_size = x.shape[1]
|
||||||
|
|
||||||
# run the initial prompt to fill the KV cache
|
# run the initial prompt to fill the KV cache
|
||||||
for n in range(prefill_size):
|
if self.arch_type == "retnet":
|
||||||
xi = x[:, n, :].unsqueeze(1)
|
for n in range(prefill_size):
|
||||||
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
|
xi = x[:, n, :].unsqueeze(1)
|
||||||
|
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
|
||||||
|
elif self.arch_type == "retnet-hf":
|
||||||
|
for n in range(prefill_size):
|
||||||
|
xi = x[:, n, :].unsqueeze(1)
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
#attention_mask=m,
|
||||||
|
inputs_embeds=x,
|
||||||
|
past_key_values=state[-1],
|
||||||
|
use_cache=state is not None,
|
||||||
|
# return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.model(**kwargs)
|
||||||
|
state.append(out.past_key_values)
|
||||||
|
|
||||||
# grab last token(s)
|
# grab last token(s)
|
||||||
x = x[:, -1, :].unsqueeze(1)
|
x = x[:, -1, :].unsqueeze(1)
|
||||||
@ -566,6 +630,21 @@ class Base(nn.Module):
|
|||||||
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
|
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||||
if _ is not None and "l_aux" in _ and self.n_experts > 1:
|
if _ is not None and "l_aux" in _ and self.n_experts > 1:
|
||||||
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
|
aux_loss = torch.sum(torch.stack([ t for t in _["l_aux"] if t is not None])) * 0.001
|
||||||
|
elif self.arch_type == "retnet-hf":
|
||||||
|
kwargs = dict(
|
||||||
|
#attention_mask=m,
|
||||||
|
inputs_embeds=x,
|
||||||
|
past_key_values=state,
|
||||||
|
use_cache=False, #state is not None,
|
||||||
|
# return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
t = self.model(**kwargs)
|
||||||
|
|
||||||
|
x = t[0]
|
||||||
|
|
||||||
|
if state is not None:
|
||||||
|
state = t[1]
|
||||||
elif self.arch_type == "bitnet":
|
elif self.arch_type == "bitnet":
|
||||||
x = self.model(x)
|
x = self.model(x)
|
||||||
# output projection layer with masking
|
# output projection layer with masking
|
||||||
|
|||||||
@ -1,3 +1,46 @@
|
|||||||
|
# https://github.com/microsoft/torchscale
|
||||||
|
|
||||||
from torchscale.architecture.config import RetNetConfig
|
from torchscale.architecture.config import RetNetConfig
|
||||||
from torchscale.architecture.retnet import RetNetDecoder
|
from torchscale.architecture.retnet import RetNetDecoder
|
||||||
# from retnet import RetNet
|
# from retnet import RetNet
|
||||||
|
|
||||||
|
# override MultiScaleRetention's forward because training with te throws an error
|
||||||
|
from torchscale.component.multiscale_retention import MultiScaleRetention, theta_shift
|
||||||
|
|
||||||
|
def MultiScaleRetention_forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
rel_pos,
|
||||||
|
chunkwise_recurrent=False,
|
||||||
|
incremental_state=None
|
||||||
|
):
|
||||||
|
bsz, tgt_len, _ = x.size()
|
||||||
|
(sin, cos), inner_mask = rel_pos
|
||||||
|
|
||||||
|
q = self.q_proj(x)
|
||||||
|
k = self.k_proj(x) * self.scaling
|
||||||
|
v = self.v_proj(x)
|
||||||
|
g = self.g_proj(x)
|
||||||
|
|
||||||
|
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
|
||||||
|
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
qr = theta_shift(q, sin, cos)
|
||||||
|
kr = theta_shift(k, sin, cos)
|
||||||
|
|
||||||
|
if incremental_state is not None:
|
||||||
|
output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
|
||||||
|
elif chunkwise_recurrent:
|
||||||
|
output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
|
||||||
|
else:
|
||||||
|
output = self.parallel_forward(qr, kr, v, inner_mask)
|
||||||
|
|
||||||
|
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
|
||||||
|
|
||||||
|
output = self.gate_fn(g) * output
|
||||||
|
|
||||||
|
output = self.out_proj(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
MultiScaleRetention.forward = MultiScaleRetention_forward
|
||||||
199
vall_e/models/retnet_hf.py
Normal file
199
vall_e/models/retnet_hf.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
# https://github.com/syncdoth/RetNet/
|
||||||
|
from ..ext.retnet_hf.configuration_retnet import RetNetConfig
|
||||||
|
from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
|
||||||
|
|
||||||
|
# things we're overriding or required to override
|
||||||
|
from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
# required to have compatibile LayerNorm
|
||||||
|
def FeedForwardNetwork_init(
|
||||||
|
self,
|
||||||
|
embed_dim,
|
||||||
|
ffn_dim,
|
||||||
|
activation_fn,
|
||||||
|
dropout,
|
||||||
|
activation_dropout,
|
||||||
|
layernorm_eps,
|
||||||
|
subln=True,
|
||||||
|
use_rms_norm=False,
|
||||||
|
):
|
||||||
|
super(FeedForwardNetwork, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.activation_fn = get_activation_fn(activation=str(activation_fn))
|
||||||
|
self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
|
||||||
|
self.dropout_module = torch.nn.Dropout(dropout)
|
||||||
|
self.fc1 = torch.nn.Linear(self.embed_dim, ffn_dim)
|
||||||
|
self.fc2 = torch.nn.Linear(ffn_dim, self.embed_dim)
|
||||||
|
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
|
||||||
|
|
||||||
|
FeedForwardNetwork.__init__ = FeedForwardNetwork_init
|
||||||
|
|
||||||
|
# removes embed_tokens
|
||||||
|
def RetNetModel_init(
|
||||||
|
self,
|
||||||
|
config: RetNetConfig,
|
||||||
|
embed_tokens: torch.nn.Embedding = None,
|
||||||
|
tensor_parallel: bool = False,
|
||||||
|
):
|
||||||
|
super(RetNetDecoder, self).__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.dropout_module = torch.nn.Dropout(config.dropout)
|
||||||
|
|
||||||
|
self.embed_dim = config.decoder_embed_dim
|
||||||
|
self.embed_scale = (
|
||||||
|
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if embed_tokens is None:
|
||||||
|
embed_tokens = torch.nn.Embedding(
|
||||||
|
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
self.embed_tokens = None
|
||||||
|
|
||||||
|
if config.layernorm_embedding:
|
||||||
|
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
||||||
|
else:
|
||||||
|
self.layernorm_embedding = None
|
||||||
|
|
||||||
|
self.layers = torch.nn.ModuleList([])
|
||||||
|
|
||||||
|
for i in range(config.decoder_layers):
|
||||||
|
self.layers.append(
|
||||||
|
RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder_layers = len(self.layers)
|
||||||
|
|
||||||
|
if config.decoder_normalize_before:
|
||||||
|
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
||||||
|
else:
|
||||||
|
self.layer_norm = None
|
||||||
|
|
||||||
|
self.retnet_rel_pos = RetNetRelPos(config)
|
||||||
|
self.recurrent_chunk_size = config.recurrent_chunk_size
|
||||||
|
|
||||||
|
if config.deepnorm:
|
||||||
|
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
|
||||||
|
for name, p in self.named_parameters():
|
||||||
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
|
p.data.div_(init_scale)
|
||||||
|
|
||||||
|
if config.subln and not config.use_glu:
|
||||||
|
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
|
||||||
|
for name, p in self.named_parameters():
|
||||||
|
if (
|
||||||
|
"fc1" in name
|
||||||
|
or "fc2" in name
|
||||||
|
or "out_proj" in name
|
||||||
|
or "v_proj" in name
|
||||||
|
):
|
||||||
|
p.data.mul_(init_scale)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = True
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
RetNetDecoder.__init__ = RetNetModel_init
|
||||||
|
|
||||||
|
# restores bias in our FFNs
|
||||||
|
def RetNetDecoderLayer_init(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
|
||||||
|
super(RetNetDecoderLayer, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
self.embed_dim = config.decoder_embed_dim
|
||||||
|
self.dropout_module = torch.nn.Dropout(config.dropout)
|
||||||
|
|
||||||
|
if config.drop_path_rate > 0:
|
||||||
|
drop_path_prob = np.linspace(
|
||||||
|
0, config.drop_path_rate, config.decoder_layers
|
||||||
|
)[depth]
|
||||||
|
self.drop_path = DropPath(drop_path_prob)
|
||||||
|
else:
|
||||||
|
self.drop_path = None
|
||||||
|
|
||||||
|
self.retention = MultiScaleRetention(
|
||||||
|
config, use_bias=True, tensor_parallel=tensor_parallel
|
||||||
|
)
|
||||||
|
|
||||||
|
self.normalize_before = config.decoder_normalize_before
|
||||||
|
|
||||||
|
self.retention_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
||||||
|
|
||||||
|
self.ffn_dim = config.decoder_ffn_embed_dim
|
||||||
|
|
||||||
|
self.ffn = self.build_ffn()
|
||||||
|
|
||||||
|
self.final_layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
|
||||||
|
|
||||||
|
if config.deepnorm:
|
||||||
|
self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
|
||||||
|
else:
|
||||||
|
self.alpha = 1.0
|
||||||
|
|
||||||
|
RetNetDecoderLayer.__init__ = RetNetDecoderLayer_init
|
||||||
|
# fixes backwards when using te's autocast
|
||||||
|
def MultiScaleRetention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
rel_pos: Tuple[Tuple[torch.Tensor]],
|
||||||
|
retention_mask: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
forward_impl: str = "parallel",
|
||||||
|
output_retentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
|
||||||
|
B, T, H = hidden_states.size()
|
||||||
|
(sin, cos), decay_mask = rel_pos
|
||||||
|
# projections
|
||||||
|
q = self.q_proj(hidden_states)
|
||||||
|
k = self.k_proj(hidden_states) * self.scaling # for scaled dot product
|
||||||
|
v = self.v_proj(hidden_states)
|
||||||
|
g = self.g_proj(hidden_states)
|
||||||
|
# multi-head
|
||||||
|
q, k, v = split_heads((q, k, v), B, T, self.num_heads)
|
||||||
|
|
||||||
|
# rotate
|
||||||
|
# NOTE: theta_shift has bug with mps device.
|
||||||
|
qr = theta_shift(q, sin, cos)
|
||||||
|
kr = theta_shift(k, sin, cos)
|
||||||
|
|
||||||
|
# retention
|
||||||
|
if forward_impl == "parallel":
|
||||||
|
retention_out, curr_kv, retention_weights = self.parallel_retention(
|
||||||
|
qr, kr, v, decay_mask
|
||||||
|
)
|
||||||
|
elif forward_impl == "recurrent":
|
||||||
|
retention_out, curr_kv = self.recurrent_retention(
|
||||||
|
qr,
|
||||||
|
kr,
|
||||||
|
v,
|
||||||
|
decay_mask,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
retention_mask=retention_mask,
|
||||||
|
)
|
||||||
|
elif forward_impl == "chunkwise":
|
||||||
|
retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"forward_impl {forward_impl} not supported.")
|
||||||
|
|
||||||
|
# concaat heads
|
||||||
|
normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
|
||||||
|
# out gate & proj
|
||||||
|
out = self.gate_fn(g) * normed
|
||||||
|
out = self.out_proj(out)
|
||||||
|
|
||||||
|
outputs = (out, curr_kv)
|
||||||
|
if output_retentions:
|
||||||
|
outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
MultiScaleRetention.forward = MultiScaleRetention_forward
|
||||||
@ -75,6 +75,19 @@ def autocast_forward( func ):
|
|||||||
return wrapper
|
return wrapper
|
||||||
Embedding.forward = autocast_forward(Embedding.forward)
|
Embedding.forward = autocast_forward(Embedding.forward)
|
||||||
|
|
||||||
|
if cfg.fp8.enabled:
|
||||||
|
import transformer_engine.pytorch as te
|
||||||
|
|
||||||
|
Linear = te.Linear
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def autocast():
|
||||||
|
yield te.fp8_autocast(enabled=True)
|
||||||
|
else:
|
||||||
|
@contextmanager
|
||||||
|
def autocast():
|
||||||
|
yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
|
||||||
|
|
||||||
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
||||||
torch.nn.Linear = Linear
|
torch.nn.Linear = Linear
|
||||||
torch.nn.Embedding = Embedding
|
torch.nn.Embedding = Embedding
|
||||||
@ -83,6 +96,7 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
|||||||
torch.optim.AdamW = AdamW
|
torch.optim.AdamW = AdamW
|
||||||
torch.optim.SGD = SGD
|
torch.optim.SGD = SGD
|
||||||
|
|
||||||
|
|
||||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||||
def replace_linear( model ):
|
def replace_linear( model ):
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user