vall-e/vall_e/models/experimental.py

"""
This is an experiment to:
* try and abide by HF's transformers API (to better benefit from its caching)
* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs
* stop trying to make a mixed AR+NAR model work, since it seems lobotomized when I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost)
+ I will not cave and go with codebook patterns, not yet.
"""
from ..config import cfg
from ..data import fold_inputs, unfold_outputs
import torch
from torch.nn.utils.rnn import pad_sequence
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
import random
import math
from einops import rearrange
from tqdm import trange
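# Each arch below is only registered if its (optional) backend imports successfully; the
# requested cfg.model.arch_type is validated against this list further down.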
AVAILABLE_ARCHES = []
try:
from transformers import LlamaForCausalLM, LlamaConfig
AVAILABLE_ARCHES.append("llama")
except Exception as e:
print("Error importing `llama` arch:", e)
pass
try:
from .retnet_hf import RetNetConfig
from ..ext.retnet_hf.modeling_retnet import RetNetForCausalLM
AVAILABLE_ARCHES.append("retnet")
except Exception as e:
print("Error importing `retnet` arch:", e)
pass
try:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm
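# Patched forward for mamba_ssm's MixerModel: mirrors the upstream flow, but wraps each layer
# in torch.utils.checkpoint when self.gradient_checkpointing is set (a flag assigned on the
# backbone in Model.__init__ below).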
def MambaMixelModel_forward(self, input_ids, inference_params=None, **mixer_kwargs):
hidden_states = self.embedding(input_ids)
residual = None
for layer in self.layers:
if self.gradient_checkpointing and hidden_states.requires_grad:
hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False )
else:
hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params )
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = MambaLayerNormFn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm_f, MambaRMSNorm)
)
return hidden_states
MambaMixelModel.forward = MambaMixelModel_forward
AVAILABLE_ARCHES.append("mamba")
except Exception as e:
print("Error importing `mamba` arch:", e)
pass
SELECTED_ARCH = cfg.model.arch_type
if SELECTED_ARCH not in AVAILABLE_ARCHES:
raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
if SELECTED_ARCH == "mamba":
LlmArchClass = MambaLMHeadModel
elif SELECTED_ARCH == "llama":
LlmArchClass = LlamaForCausalLM
elif SELECTED_ARCH == "retnet":
LlmArchClass = RetNetForCausalLM
else:
raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
class Model(LlmArchClass):
def __init__(
self,
d_model=1024,
n_layers=12,
n_heads=16,
p_dropout=0.1,
config = None,
):
self.hyper_config = config
hf_attention = config.attention if config is not None else None
gradient_checkpointing = config.gradient_checkpointing if config is not None else True
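# Shared vocabulary: 256 text tokens + 1024 codes per quantizer level for the prompt + the
# same again for the response + 1 (presumably the stop/separator token).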
vocab_size = 256 + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
if SELECTED_ARCH == "llama":
super().__init__(config=LlamaConfig(
vocab_size=vocab_size,
hidden_size=d_model,
max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.max_levels * 60, # max-length of 60 seconds
intermediate_size=d_model*4,
num_hidden_layers=n_layers,
num_attention_heads=n_heads,
attention_dropout=p_dropout,
num_key_value_heads=n_heads,
sliding_window=cfg.dataset.frames_per_second * cfg.model.max_levels * 12,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation=hf_attention,
))
if gradient_checkpointing:
self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
elif SELECTED_ARCH == "retnet":
super().__init__(config=RetNetConfig(
vocab_size=vocab_size,
decoder_embed_dim=d_model,
decoder_value_embed_dim=d_model * 2,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=gradient_checkpointing,
activation_fn="gelu",
use_layernorm=False,
use_biases=False,
use_glu=True,
#chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
#recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
#no_output_layer=True,
#rotary_embedding_base=self.rotary_embedding_base, # 10000
decoder_normalize_before=True,
))
elif SELECTED_ARCH == "mamba":
super().__init__(config=MambaConfig(
vocab_size=vocab_size,
d_model=d_model,
n_layer=n_layers*2,
#ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan
))
self.backbone.gradient_checkpointing = gradient_checkpointing
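# mamba_ssm's generate() does not follow HF's signature: it takes no attention mask or
# do_sample flag, so those kwargs are stripped, and "cg" enables its CUDA-graph caching.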
def generate(
self,
*args,
**kwargs
):
if SELECTED_ARCH == "mamba":
kwargs["cg"] = True
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
if "do_sample" in kwargs:
kwargs.pop("do_sample")
return super().generate(*args, **kwargs)
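# llama/retnet compute the causal-LM loss themselves when `labels` is passed; mamba's LM head
# does not, so the usual shift-by-one cross-entropy is computed by hand below. Either way the
# result is stashed in self.loss for the surrounding training code.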
def forward(
self,
*args,
**kwargs,
):
if SELECTED_ARCH == "mamba":
if "attention_mask" in kwargs:
kwargs.pop("attention_mask")
output = super().forward(*args, **kwargs)
if SELECTED_ARCH in ["llama", "retnet"]:
if output.loss is not None:
self.loss = dict(
nll = output.loss,
)
elif SELECTED_ARCH == "mamba":
if "labels" in kwargs:
labels = kwargs.pop("labels")
logits = output.logits
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
# move the labels to the logits' device (this is what keeps model parallelism working)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
self.loss = dict(
nll = loss,
)
return output
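# Self-contained smoke test: overfits a tiny model on a single utterance, then decodes a sample
# to ./data/ (e.g. via `python -m vall_e.models.experimental`, assuming the quantized clip and
# config it expects are present).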
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_100 # DAC's high-fidelity model runs at 44.1 kHz
from functools import partial
from einops import repeat
from tqdm import tqdm
from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from ..utils import wrapper as ml
import numpy as np
import re
device = "cuda"
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.max_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
]
prom_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resp_list = [
qnt[:, :].to(device),
#qnt[cfg.dataset.frames_per_second:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]
prom_list = prom_list[:1]
resp_list = resp_list[:1]
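# Disabled sanity check: folds inputs at successive quant levels and prints what
# unfold_outputs() recovers, to eyeball that the fold/unfold round-trip is lossless.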
if False:
output_list = [ [] ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[0])
unfolded = unfold_outputs( input_ids, quant_levels=[0])
print( 0, "inputs:", input_ids.shape, input_ids )
print( 0, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 0] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[1])
unfolded = unfold_outputs( input_ids, quant_levels=[1])
print( 1, "inputs:", input_ids.shape, input_ids )
print( 1, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 1] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[2])
unfolded = unfold_outputs( input_ids, quant_levels=[2])
print( 2, "inputs:", input_ids.shape, input_ids )
print( 2, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 2] )
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=output_list, targ_list=resp_list, quant_levels=[3])
unfolded = unfold_outputs( input_ids, quant_levels=[3])
print( 3, "inputs:", input_ids.shape, input_ids )
print( 3, "outputs:", unfolded["resp_list"][0].shape, unfolded["resp_list"][0] )
output_list[0].append( resp_list[0][:, 3] )
return
kwargs = {}
model = Model(**kwargs).to(device)
steps = 50 if cfg.model.interleave else 250
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
if cfg.optimizations.dadaptation:
# D-Adaptation manages the learning rate itself; do not combine it with schedulefree
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
engine = Engine(model=model, optimizer=optimizer)
torch.save( {
'module': model.state_dict()
}, f"./data/{SELECTED_ARCH}.pth" )
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):
engine.eval()
target_length = 0
resp_list = None
if cfg.model.interleave:
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list)
output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
unfolded = unfold_outputs( output )
resp_list = unfolded["resp_list"]
else:
resp_list = [ [] for _ in range(len(text_list)) ]
for l in range(cfg.model.max_levels):
quant_levels = [ [ l ] for _ in range(len(text_list)) ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels, experimental=True)
min_length = len(input_ids[0]) + 1
output = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
min_length=min_length,
max_length=min_length+steps*2,
eos_token_id=3,
do_sample=False
)
unfolded = unfold_outputs( output, quant_levels=quant_levels )
if l == 0:
steps = 0
for batch, resp in enumerate(unfolded["resp_list"]):
length = resp.shape[-1]
# store length
if l == 0:
steps = max( steps, length )
# pad
else:
resp = resp[:steps]
if length < steps:
resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ])
resp_list[batch].append( resp )
for i, resp in enumerate( resp_list ):
resp_list[i] = torch.stack( resp ).t()
for i, batch in enumerate(resp_list):
_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
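# Training loop: unless interleaving, each batch item trains on a randomly chosen quant level;
# inputs and label ids are folded in parallel (labels with ignore_index=-100, presumably so the
# loss only covers the target span), and the engine handles the optimizer step.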
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
else:
resps_list = [ resp for resp in resp_list ]
input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels)
target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
torch.save( {
'module': model.state_dict()
}, f"./data/{SELECTED_ARCH}.pth" )
#sample("init", 5)
train()
sample("final")
if __name__ == "__main__":
example_usage()