mrq 2024-06-04 00:07:00 -05:00
parent 186b93a77e
commit c93d5863fd
8 changed files with 178 additions and 153 deletions

View File

@@ -44,7 +44,7 @@ def fold_inputs(
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
audio_rvq_levels = cfg.model.max_levels
):
def _create_mask(l, device):
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -107,7 +107,7 @@ def unfold_outputs(
text_tokens = 256,
audio_tokens = 1024,
audio_rvq_levels = cfg.model.prom_levels
audio_rvq_levels = cfg.model.max_levels
):
device = output_ids.device
batch_size = output_ids.shape[0]
@@ -139,7 +139,7 @@ def unfold_outputs(
bins[rvq].append( prom_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
resp_len = len(resp_list[i])
@@ -152,9 +152,9 @@ def unfold_outputs(
bins[rvq].append( resp_list[i][pos] )
nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
bins = bins[:nearest]
resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64)
return dict(
text_list=text_list,
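Note: besides switching the default RVQ depth from cfg.model.prom_levels to cfg.model.max_levels, this hunk now creates the regrouped prompt/response/text tensors directly on the output device instead of leaving them on the CPU. A self-contained sketch of the regrouping step, assuming a flat level-interleaved token list (flat_tokens and n_levels are illustrative names, not from the repo):

import torch

def regroup_codes(flat_tokens, n_levels, device):
    # Drop trailing tokens that do not fill a complete RVQ frame.
    nearest = (len(flat_tokens) // n_levels) * n_levels
    flat_tokens = flat_tokens[:nearest]
    # Build the (frames, levels) tensor directly on `device` to avoid a CPU round-trip.
    return torch.tensor(flat_tokens, dtype=torch.int64, device=device).reshape(-1, n_levels)

codes = regroup_codes(list(range(12)), n_levels=4, device=torch.device("cpu"))
assert codes.shape == (3, 4)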

View File

@@ -963,6 +963,7 @@ class RetNetModel(RetNetPreTrainedModel):
retention_mask,
forward_impl,
past_key_value,
use_reentrant=True,
)
else:
layer_outputs = layer(
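Note: the only change here is passing use_reentrant=True to the activation-checkpointing call; recent PyTorch versions warn when the flag is omitted and plan to flip the default, so pinning it keeps the legacy behaviour explicit. A minimal sketch of the call (layer and hidden_states are stand-ins for the decoder layer and its inputs):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 4, 16, requires_grad=True)

# Recompute this layer's activations during backward instead of storing them.
out = checkpoint(layer, hidden_states, use_reentrant=True)
out.sum().backward()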

View File

@@ -1,10 +1,9 @@
from .ar_nar import AR_NAR
from .experimental import Model as Experimental
def get_model(cfg, training=True):
name = cfg.name
if not cfg.experimental:
from .ar_nar import AR_NAR
model = AR_NAR(
n_tokens=cfg.tokens,
d_model=cfg.dim,
@@ -21,6 +20,7 @@ def get_model(cfg, training=True):
)
model._cfg = cfg
else:
from .experimental import Model as Experimental
model = Experimental(
d_model=cfg.dim,
n_layers=cfg.layers,
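Note: the top-level imports of AR_NAR and Experimental are moved inside their respective branches, so selecting one architecture no longer requires the other one (and its optional dependencies) to import cleanly. A tiny self-contained sketch of the same pattern, using stdlib modules in place of the model classes:

def get_backend(name):
    # Import inside the branch so the module is only loaded when selected,
    # mirroring the branch-local `from .ar_nar import AR_NAR` above.
    if name == "json":
        import json as backend
    else:
        import pickle as backend
    return backend

assert get_backend("json").__name__ == "json"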

View File

@@ -386,7 +386,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 200
steps = 50
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@@ -448,7 +448,7 @@ def example_usage():
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
}, f"./data/{cfg.model.arch_type}.pth" )
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@@ -459,16 +459,11 @@ def example_usage():
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
if cfg.audio_backend != "dac":
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
for i, o in enumerate(resps_list):
_ = decode_to_file(o, f"data/ar+nar.{i}.{name}.wav", device=device)
_ = decode_to_file(o, f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
@@ -484,7 +479,7 @@ def example_usage():
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
}, f"./data/{cfg.model.arch_type}.pth" )
sample("init", 5)
train()
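Note: the test checkpoint and the rendered wavs are now named after cfg.model.arch_type and cfg.audio_backend instead of hard-coded test.pth / ar.*.wav, and the intermediate AR-only decode is dropped in favour of a single combined output, so runs with different architectures or codecs stop overwriting each other. A runnable sketch of the same save/load layout with a toy module (the arch_type value is illustrative, not read from a real cfg):

import os, tempfile, torch

model = torch.nn.Linear(8, 8)
arch_type = "ar+nar"  # illustrative; the real value comes from cfg.model.arch_type
path = os.path.join(tempfile.gettempdir(), f"{arch_type}.pth")

# Same {'module': state_dict} layout the example uses above.
torch.save({"module": model.state_dict()}, path)
model.load_state_dict(torch.load(path, map_location="cpu")["module"])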

View File

@@ -31,6 +31,15 @@ except Exception as e:
print("Error importing `llama` arch:", e)
pass
try:
from .retnet_hf import RetNetConfig
from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
AVAILABLE_ARCHES.append("retnet")
except Exception as e:
print("Error importing `retnet` arch:", e)
pass
try:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
@@ -75,6 +84,8 @@ if SELECTED_ARCH == "mamba":
LlmArchClass = MambaLMHeadModel
elif SELECTED_ARCH == "llama":
LlmArchClass = LlamaForCausalLM
elif SELECTED_ARCH == "retnet":
LlmArchClass = RetNetForCausalLM
else:
raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
@@ -92,18 +103,19 @@ class Model(LlmArchClass):
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
if SELECTED_ARCH == "llama":
super().__init__(config=LlamaConfig(
vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
vocab_size=vocab_size,
hidden_size=d_model,
max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
@@ -114,9 +126,31 @@ class Model(LlmArchClass):
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif SELECTED_ARCH == "retnet":
super().__init__(config=RetNetConfig(
vocab_size=vocab_size,
decoder_embed_dim=d_model,
decoder_value_embed_dim =d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=gradient_checkpointing,
activation_fn="gelu",
use_layernorm=False,
use_biases=False,
use_glu=True,
#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
#no_output_layer=True,
#rotary_embedding_base=self.rotary_embedding_base, # 10000
decoder_normalize_before=True,
))
elif SELECTED_ARCH == "mamba":
super().__init__(config=MambaConfig(
vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layers*2,
#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
@@ -132,15 +166,15 @@ class Model(LlmArchClass):
):
output = super().forward(*args, **kwargs)
if SELECTED_ARCH == "llama":
if SELECTED_ARCH in ["llama", "retnet"]:
if output.loss is not None:
self.loss = dict(
nll = output.loss,
)
elif SELECTED_ARCH == "mamba":
if "labels" in kwargs:
logits = output.logits
labels = kwargs.pop("labels")
logits = output.logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
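Note: for the mamba arch the loss is computed by hand: labels are popped from the kwargs, the logits are shifted left by one position, and the standard next-token cross-entropy is taken, since this branch does not get a ready-made loss back the way the llama/retnet heads do. A self-contained sketch of that shift with toy shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 32)          # (batch, seq, vocab)
labels = torch.randint(0, 32, (2, 10))   # (batch, seq)

# Tokens < n predict n: score position t against the label at t+1.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))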
@@ -183,7 +217,7 @@ def example_usage():
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
@@ -278,7 +312,7 @@ def example_usage():
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*60 ):
engine.eval()
if SELECTED_ARCH == "mamba":
output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
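Note: the repeated 256 + (1024 * levels) + (1024 * levels) + 1 expression is hoisted into a single vocab_size, and every cfg.model.prom_levels in this file becomes cfg.model.max_levels. The flattened token space that formula implies is 256 text tokens, one 1024-entry codebook block per RVQ level for the prompt audio, another per level for the response audio, and one trailing stop token. A sketch of the offset arithmetic under that assumed layout (helper names are illustrative, not the repo's):

TEXT_TOKENS = 256
AUDIO_TOKENS = 1024

def prom_token_id(code, level):
    # Prompt audio codes sit directly after the text tokens, one block per level.
    return TEXT_TOKENS + level * AUDIO_TOKENS + code

def resp_token_id(code, level, max_levels):
    # Response audio codes follow all of the prompt blocks.
    return TEXT_TOKENS + max_levels * AUDIO_TOKENS + level * AUDIO_TOKENS + code

def stop_token_id(max_levels):
    # Last id in the table: 256 + 2 * 1024 * max_levels.
    return TEXT_TOKENS + 2 * max_levels * AUDIO_TOKENS

assert stop_token_id(8) + 1 == 256 + (1024 * 8) + (1024 * 8) + 1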

View File

@@ -1,6 +1,6 @@
# https://github.com/syncdoth/RetNet/
from ..ext.retnet_hf.configuration_retnet import RetNetConfig
from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder
from ..ext.retnet_hf.modeling_retnet import RetNetModel as RetNetDecoder, RetNetForCausalLM
# things we're overriding or required to override
from ..ext.retnet_hf.modeling_retnet import RetNetDecoderLayer, MultiScaleRetention, theta_shift, split_heads, RMSNorm, FeedForwardNetwork, get_activation_fn, LayerNorm, RetNetRelPos
@@ -32,77 +32,74 @@ def FeedForwardNetwork_init(
FeedForwardNetwork.__init__ = FeedForwardNetwork_init
# removes embed_tokens
def RetNetModel_init(
self,
config: RetNetConfig,
embed_tokens: torch.nn.Embedding = None,
tensor_parallel: bool = False,
):
super(RetNetDecoder, self).__init__(config)
self.config = config
self,
config: RetNetConfig,
embed_tokens: torch.nn.Embedding = None,
tensor_parallel: bool = False,
):
super(RetNetDecoder, self).__init__(config)
self.config = config
self.dropout_module = torch.nn.Dropout(config.dropout)
self.dropout_module = torch.nn.Dropout(config.dropout)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
)
if embed_tokens is None and config.vocab_size:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
self.embed_tokens = embed_tokens
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
self.layers = torch.nn.ModuleList([])
for i in range(config.decoder_layers):
self.layers.append(
RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
)
"""
if embed_tokens is None:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
"""
self.embed_tokens = None
self.decoder_layers = len(self.layers)
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.layers = torch.nn.ModuleList([])
self.retnet_rel_pos = RetNetRelPos(config)
self.recurrent_chunk_size = config.recurrent_chunk_size
for i in range(config.decoder_layers):
self.layers.append(
RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
)
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
self.decoder_layers = len(self.layers)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(config)
self.recurrent_chunk_size = config.recurrent_chunk_size
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
self.gradient_checkpointing = True
self.post_init()
self.gradient_checkpointing = True
self.post_init()
RetNetDecoder.__init__ = RetNetModel_init
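Note: rather than subclassing, the adapter swaps the upstream constructor out at import time: a module-level replacement is assigned to RetNetDecoder.__init__, so every later instantiation (including the one inside RetNetForCausalLM) runs the patched code; this revision appears to restore building embed_tokens when none is supplied, where the earlier patch nulled it out. A minimal self-contained sketch of the monkey-patch pattern with toy classes:

class Upstream:
    def __init__(self, vocab_size):
        self.embed_tokens = None  # upstream behaviour we want to change

def patched_init(self, vocab_size, embed_tokens=None):
    # Replacement constructor: build the table only when one was not supplied.
    self.embed_tokens = embed_tokens if embed_tokens is not None else list(range(vocab_size))

Upstream.__init__ = patched_init
assert len(Upstream(4).embed_tokens) == 4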

View File

@@ -36,83 +36,81 @@ FeedForwardNetwork.__init__ = FeedForwardNetwork_init
# removes embed_tokens
def RetNetModel_init(
self, config, embed_tokens=None, output_projection=None, **kwargs
):
super(RetNetDecoder, self).__init__(**kwargs)
self.config = config
self, config, embed_tokens=None, output_projection=None, **kwargs
):
super(RetNetDecoder, self).__init__(**kwargs)
self.config = config
self.dropout_module = torch.nn.Dropout(config.dropout)
self.dropout_module = torch.nn.Dropout(config.dropout)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
self.embed_dim = config.decoder_embed_dim
self.embed_scale = (
1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
)
if embed_tokens is None and config.vocab_size:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
self.embed_tokens = embed_tokens
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
self.output_projection = self.build_output_projection(config)
else:
self.output_projection = output_projection
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
self.layers = torch.nn.ModuleList([])
for i in range(config.decoder_layers):
layer = self.build_decoder_layer(
config,
depth=i,
)
"""
if embed_tokens is None:
embed_tokens = torch.nn.Embedding(
config.vocab_size, config.decoder_embed_dim, config.pad_token_id
)
if config.checkpoint_activations:
layer = checkpoint_wrapper(layer)
"""
self.embed_tokens = None
self.layers.append(layer)
if (output_projection is None and not config.no_output_layer and config.vocab_size > 0):
self.output_projection = self.build_output_projection(config)
else:
self.output_projection = output_projection
self.num_layers = len(self.layers)
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layernorm_embedding = None
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.layers = torch.nn.ModuleList([])
self.retnet_rel_pos = RetNetRelPos(config)
self.chunkwise_recurrent = config.chunkwise_recurrent
self.recurrent_chunk_size = config.recurrent_chunk_size
for i in range(config.decoder_layers):
layer = self.build_decoder_layer(
config,
depth=i,
)
"""
if config.checkpoint_activations:
layer = checkpoint_wrapper(layer)
"""
self.layers.append(layer)
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
self.num_layers = len(self.layers)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
if config.decoder_normalize_before:
self.layer_norm = LayerNorm(self.embed_dim, eps=config.layernorm_eps) # RMSNorm
else:
self.layer_norm = None
self.retnet_rel_pos = RetNetRelPos(config)
self.chunkwise_recurrent = config.chunkwise_recurrent
self.recurrent_chunk_size = config.recurrent_chunk_size
if config.deepnorm:
init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.div_(init_scale)
if config.subln and not config.use_glu:
init_scale = math.sqrt(math.log(config.decoder_layers * 2))
for name, p in self.named_parameters():
if (
"fc1" in name
or "fc2" in name
or "out_proj" in name
or "v_proj" in name
):
p.data.mul_(init_scale)
self.gradient_checkpointing = True
self.gradient_checkpointing = True
RetNetDecoder.__init__ = RetNetModel_init
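Note: both patched constructors end with the same DeepNorm/SubLN weight rescaling loops: under DeepNorm the fc1/fc2/out_proj/v_proj parameters are divided by (8 * decoder_layers) ** 0.25 right after construction, and under SubLN without GLU they are multiplied by sqrt(log(2 * decoder_layers)). A toy sketch of the DeepNorm branch (the ModuleDict is a stand-in for the real decoder):

import math
import torch

model = torch.nn.ModuleDict({
    "fc1": torch.nn.Linear(4, 4),
    "out_proj": torch.nn.Linear(4, 4),
    "other": torch.nn.Linear(4, 4),
})

decoder_layers = 12
init_scale = math.pow(8.0 * decoder_layers, 0.25)

for name, p in model.named_parameters():
    # Only the projections named in the DeepNorm recipe get rescaled.
    if any(key in name for key in ("fc1", "fc2", "out_proj", "v_proj")):
        p.data.div_(init_scale)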

View File

@@ -109,9 +109,9 @@ def run_eval(engines, eval_name, dl):
if engine.hyper_config.experimental:
input_ids, attention_mask = fold_inputs(
text_list=batch["text"],
proms_list=batch["proms"],
prom_list=batch["proms"],
)
output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)
resps_list = unfold_outputs( output )["resp_list"]
else:
resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
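Note: the experimental evaluation path now calls generate() through engine.module instead of engine.model, and renames the prompt kwarg from proms_list to prom_list to match fold_inputs. With a DeepSpeed-style engine the wrapped network is exposed as .module, which is where the HF-style generate() lives. A small defensive unwrap sketch (the hasattr fallback is an assumption, not repo code):

def unwrap(engine):
    # DeepSpeed engines and DDP wrappers expose the inner network as `.module`;
    # fall back to the object itself when it is already a bare model.
    return engine.module if hasattr(engine, "module") else engine

# module = unwrap(engine)
# output = module.generate(input_ids=input_ids, attention_mask=attention_mask,
#                          max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False)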