diff --git a/data/config.yaml b/data/config.yaml index a46cec6..e2016e1 100644 --- a/data/config.yaml +++ b/data/config.yaml @@ -1,3 +1,4 @@ +experimental: False # should probably expand this into a dict of experimental flags sample_rate: 24_000 # 44_000 for dac audio_backend: "vocos" # or dac @@ -11,9 +12,10 @@ models: tones: 1 arch_type: llama training: True - version: 4 + version: 5 attention: flash_attention_2 dropout: 0.1 + experimental: False loss_factors: text: 0.1 @@ -63,11 +65,10 @@ trainer: keep_last_checkpoints: 4 - aggressive_optimizations: False - load_disabled_engines: False + gradient_checkpointing: True - #load_state_dict: True strict_loading: False + #load_state_dict: True #load_tag: "9500" #load_states: False #restart_step_count: True @@ -85,8 +86,6 @@ trainer: amp: False - activation_checkpointing: True - load_webui: False inference: @@ -109,8 +108,6 @@ optimizations: bitnet: False fp8: False -experimental: True # practically required now it seems - dataset: speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" speaker_group_getter: "lambda p: f'{p.parts[-3]}'" diff --git a/vall_e/config.py b/vall_e/config.py index 4adf307..8f2d715 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -191,6 +191,7 @@ class Dataset: def max_duration(self): return self.duration_range[1] +# I really need to clean this up @dataclass() class Model: _max_levels: int = 0 @@ -215,6 +216,7 @@ class Model: dropout: float = 0.1 # adjustable dropout value loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) kv_heads: int = 0 + experimental: bool = False def get(self, name=None): return [ self ] if not name or self.name == name else [] @@ -306,6 +308,10 @@ class Model: def activation_checkpointing(self): return cfg.trainer.activation_checkpointing + @property + def gradient_checkpointing(self): + return cfg.trainer.gradient_checkpointing + @dataclass() class Hyperparameters: batch_size: int = 8 @@ -519,7 +525,8 @@ class Trainer: load_module_only: bool = False restart_step_count: bool = False - activation_checkpointing: bool = True + activation_checkpointing: bool | None = None # deprecated + gradient_checkpointing: bool = True aggressive_optimizations: bool = False check_for_oom: bool = True @@ -728,6 +735,9 @@ class Config(_Config): if self.inference.audio_backend != "" and self.audio_backend == "": self.audio_backend = self.inference.audio_backend + if self.trainer.activation_checkpointing is not None: + self.trainer.gradient_checkpointing = self.trainer.activation_checkpointing + # Preserves the old behavior class NaiveTokenizer: def get_vocab( self ): diff --git a/vall_e/data.py b/vall_e/data.py index 5a58f2f..734d178 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -24,11 +24,144 @@ from typing import Any from torch import Tensor from torch.utils.data import DataLoader, Dataset as _Dataset from torch.utils.data.distributed import DistributedSampler +from torch.nn.utils.rnn import pad_sequence + from tqdm.auto import tqdm # torch.multiprocessing.set_sharing_strategy("file_system") _logger = logging.getLogger(__name__) +# fold into a typical LLM sequence (one embedding rather than split embeddings) +def fold_inputs( + text_list = [], + prom_list = [], + resp_list = [], + + ignore_index = None, + + sep = 3, + stop = 3, + + text_tokens = 256, + audio_tokens = 1024, + audio_rvq_levels = cfg.model.prom_levels +): + def _create_mask(l, device): + seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) + stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) + return (seq < stop).float() # (b t) + + def list_to_tensor(x_list: list[Tensor]): + l = list(map(len, x_list)) + x = pad_sequence(x_list).t() + + m = _create_mask(l, x_list[0].device) + m = m.to(x) + return x, m + + device = text_list[0].device + batch_size = len(text_list) + input_ids = [ [] for _ in range(batch_size) ] + + offset = 0 + + sep = torch.Tensor([ sep ]) + stop = torch.Tensor([ stop ]) + + for i, text in enumerate(text_list): + seq = text.to("cpu", dtype=torch.int64) + input_ids[i].append( seq ) + input_ids[i].append( sep ) + + offset = text_tokens + for i, prom in enumerate(prom_list): + if ignore_index is not None: + seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) + else: + seq = prom.flatten().to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + + input_ids[i].append( seq ) + input_ids[i].append( sep ) + + offset = text_tokens + (audio_tokens * audio_rvq_levels) + for i, resp in enumerate(resp_list): + seq = resp.flatten().to("cpu", dtype=torch.int64) + for idx, token in enumerate( seq ): + token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + input_ids[i].append( seq ) + input_ids[i].append( stop ) + + for i, batch in enumerate(input_ids): + input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) + + return list_to_tensor(input_ids) + +# unfold from one unified token ID space to separate token spaces +def unfold_outputs( + output_ids, + + sep = 3, + stop = 3, + + text_tokens = 256, + audio_tokens = 1024, + audio_rvq_levels = cfg.model.prom_levels +): + device = output_ids.device + batch_size = output_ids.shape[0] + + text_list = [ [] for _ in range(batch_size) ] + prom_list = [ [] for _ in range(batch_size) ] + resp_list = [ [] for _ in range(batch_size) ] + + for i, batch in enumerate( output_ids ): + for idx, token in enumerate( batch ): + id = token.item() + if id == sep or id == stop: + continue + + if 0 <= id and id < text_tokens: + text_list[i].append( id ) + elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels): + prom_list[i].append( (id - text_tokens) % audio_tokens ) + elif text_tokens + (audio_tokens * audio_rvq_levels) <= id: + resp_list[i].append( (id - text_tokens) % audio_tokens ) + + prom_len = len(prom_list[i]) + if prom_len % audio_rvq_levels == 0 and False: + prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() + else: + bins = [ [] for _ in range(audio_rvq_levels) ] + for pos in range( prom_len ): + rvq = pos % audio_rvq_levels + bins[rvq].append( prom_list[i][pos] ) + nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels + bins = bins[:nearest] + prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64) + + + resp_len = len(resp_list[i]) + if len(resp_list[i]) % audio_rvq_levels == 0 and False: + resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() + else: + bins = [ [] for _ in range(audio_rvq_levels) ] + for pos in range( resp_len ): + rvq = pos % audio_rvq_levels + bins[rvq].append( resp_list[i][pos] ) + nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels + bins = bins[:nearest] + resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64) + + text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64) + + return dict( + text_list=text_list, + prom_list=prom_list, + resp_list=resp_list + ) + # to-do: clean up this symmap mess def get_phone_symmap(): return cfg.tokenizer.get_vocab() diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 40b806a..6890c62 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -33,7 +33,7 @@ def load_engines(training=True): optimizer = None lr_scheduler = None - inferencing = cfg.mode == "inferencing" or not model._cfg.training + inferencing = cfg.mode == "inferencing" or not model.hyper_config.training backend = cfg.inference.backend if inferencing else cfg.trainer.backend dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype amp = cfg.inference.amp if inferencing else cfg.trainer.amp @@ -43,7 +43,7 @@ def load_engines(training=True): engine_class = _Engine if backend == "local" or inferencing else Engine if inferencing: - model._cfg.training = False + model.hyper_config.training = False if cfg.optimizations.replace and cfg.optimizations.linear: model.model = ml.replace_linear( model.model ) @@ -83,7 +83,7 @@ def load_engines(training=True): params.update(cfg.hyperparameters.optimizer_params) optimizer = optimizer_class( - [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], + [ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ], **params, ) @@ -96,7 +96,7 @@ def load_engines(training=True): raise ValueError(f'ScheduleFree not implemented with requested optimizer: {cfg.hyperparameters.optimizer}') optimizer = scheduler_class( - [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ], + [ param for name, param in model.named_parameters() if name not in model.hyper_config.frozen_params ], lr = params['lr'], warmup_steps = cfg.hyperparameters.warmup_steps ) @@ -144,7 +144,7 @@ def load_engines(training=True): model.load_state_dict(state, strict=cfg.trainer.strict_loading) - _cfg = model._cfg + hyper_config = model.hyper_config # wrap if DDP is requested if ddp: @@ -161,7 +161,7 @@ def load_engines(training=True): optimizer=optimizer, lr_scheduler=lr_scheduler, - _cfg=_cfg, + hyper_config=hyper_config, stats=stats ) diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index ae20ad8..9369adc 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -52,9 +52,9 @@ if not distributed_initialized() and cfg.trainer.backend == "local": # and world # A very naive engine implementation using barebones PyTorch class Engine(): def __init__(self, *args, **kwargs): - if '_cfg' in kwargs: - self._cfg = kwargs['_cfg'] - kwargs.pop("_cfg") + if 'hyper_config' in kwargs: + self.hyper_config = kwargs['hyper_config'] + kwargs.pop("hyper_config") self.module = kwargs['model'].to(cfg.device).to(torch.float32 if cfg.trainer.amp else cfg.trainer.dtype) self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None @@ -72,11 +72,11 @@ class Engine(): def freeze(self, freeze_all=True): # set to freeze - if self._cfg is None or not hasattr(self._cfg, "frozen_params"): - raise Exception("freeze_all=False yet self._cfg.frozen_params is None") + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") for name, param in self.module.named_parameters(): - if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params): + if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): param.requires_grad_(False) self._frozen_params.add(param) @@ -87,9 +87,9 @@ class Engine(): @property def _training(self): - if not hasattr(self, "_cfg"): + if not hasattr(self, "hyper_config"): return True - return self._cfg.training + return self.hyper_config.training @property def global_step(self): diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index 1ef8333..08258ae 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -32,10 +32,10 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed": class Engine(DeepSpeedEngine): def __init__(self, *args, **kwargs): - self._cfg = None - if '_cfg' in kwargs: - self._cfg = kwargs['_cfg'] - kwargs.pop("_cfg") + self.hyper_config = None + if 'hyper_config' in kwargs: + self.hyper_config = kwargs['hyper_config'] + kwargs.pop("hyper_config") kwargs['config'] = cfg.trainer.deepspeed.ds_cfg kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) @@ -63,11 +63,11 @@ class Engine(DeepSpeedEngine): self.max_nan_losses = 8 def freeze(self, freeze_all=True): - if self._cfg is None or not hasattr(self._cfg, "frozen_params"): - raise Exception("freeze_all=False yet self._cfg.frozen_params is None") + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") for name, param in self.module.named_parameters(): - if (freeze_all and param.requires_grad) or (not freeze_all and name in self._cfg.frozen_params): + if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): param.requires_grad_(False) self._frozen_params.add(param) @@ -78,7 +78,7 @@ class Engine(DeepSpeedEngine): @property def _training(self): - return self._cfg.training + return self.hyper_config.training @property def global_step(self): diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 59a86d7..97309ae 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -1,23 +1,34 @@ from .ar_nar import AR_NAR +from .experimental import Model as Experimental def get_model(cfg, training=True): name = cfg.name - model = AR_NAR( - n_tokens=cfg.tokens, - d_model=cfg.dim, - n_heads=cfg.heads, - n_layers=cfg.layers, - n_experts=cfg.experts, - - p_dropout=cfg.dropout, - - l_padding = cfg.input_alignment, - - training = training, - config = cfg, - ) - model._cfg = cfg + if not cfg.experimental: + model = AR_NAR( + n_tokens=cfg.tokens, + d_model=cfg.dim, + n_heads=cfg.heads, + n_layers=cfg.layers, + n_experts=cfg.experts, + + p_dropout=cfg.dropout, + + l_padding = cfg.input_alignment, + + training = training, + config = cfg, + ) + model._cfg = cfg + else: + model = Experimental( + d_model=cfg.dim, + n_layers=cfg.layers, + n_heads=cfg.heads, + p_dropout=cfg.dropout, + + config = cfg, + ) print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 6a86ca3..f1dd72a 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -64,7 +64,7 @@ try: def BitNetTransformerBlock_forward(self, x: Tensor, *args, **kwargs) -> Tensor: skip = x for attn, ffn in zip(self.layers, self.ffn_layers): - if x.requires_grad and self.activation_checkpointing: + if x.requires_grad and self.gradient_checkpointing: x, _ = checkpoint(attn, x, x, x, is_causal=True, *args, **kwargs, use_reentrant=False) else: x, _ = attn(x, x, x, is_causal=True, *args, **kwargs) @@ -83,13 +83,13 @@ try: num_tokens: int, heads=8, ff_mult=4, - activation_checkpointing = True + gradient_checkpointing = True ): super().__init__() self.transformer = BitNetTransformerBlock( dim=dim, depth=depth, heads=heads, ff_mult=ff_mult ) self.norm = BitNetRMSNorm(dim) - self.transformer.activation_checkpointing = activation_checkpointing + self.transformer.gradient_checkpointing = gradient_checkpointing def forward(self, x): x = self.transformer(x) @@ -431,9 +431,9 @@ class Base(nn.Module): return -100 def loss_factor(self, k): - if self.config is None: + if self.hyper_config is None: return 1.0 - return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0 + return self.hyper_config.loss_factors[k] if k in self.hyper_config.loss_factors else 1.0 def __init__( self, @@ -452,8 +452,8 @@ class Base(nn.Module): ): super().__init__() self.training = training - self.config = config - self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True + self.hyper_config = config + self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True self.n_tokens = n_tokens self.d_model = d_model @@ -482,13 +482,13 @@ class Base(nn.Module): self.proms_emb = AudioEmbedding( [n_prom_tokens] * self.n_prom_levels, d_model, levels=self.n_prom_levels if self.version > 3 else None, - sums=self.config.audio_embedding_sums if self.config is not None else True, + sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True, ) # [1025] + [1024] * 8 self.resps_emb = AudioEmbedding( [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model, levels=self.n_resp_levels if self.version > 3 else None, - sums=self.config.audio_embedding_sums if self.config is not None else True + sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True ) @@ -502,20 +502,20 @@ class Base(nn.Module): self.sep = nn.Parameter(torch.randn(d_model)) # ick, there has to be a better way - hf_attention = self.config.attention if self.config is not None else None + hf_attention = self.hyper_config.attention if self.hyper_config is not None else None - if self.config.attention == "auto": + if self.hyper_config.attention == "auto": if "flash" in AVAILABLE_ATTENTIONS: - self.config.attention = "flash" + self.hyper_config.attention = "flash" elif "xformers" in AVAILABLE_ATTENTIONS: - self.config.attention = "xformers" + self.hyper_config.attention = "xformers" else: - self.config.attention = "mem_efficient" + self.hyper_config.attention = "mem_efficient" - if self.config.attention in ["xformers", "mem_efficient", "math", "flash"]: + if self.hyper_config.attention in ["xformers", "mem_efficient", "math", "flash"]: hf_attention = None - if self.config.attention not in AVAILABLE_ATTENTIONS: - raise ValueError(f"Requesting attention `{self.config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") + if self.hyper_config.attention not in AVAILABLE_ATTENTIONS: + raise ValueError(f"Requesting attention `{self.hyper_config.attention}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") if self.arch_type == "transformer": @@ -538,12 +538,12 @@ class Base(nn.Module): num_hidden_layers=n_layers, num_attention_heads=n_heads, attention_dropout=p_dropout if training else 0.0, - num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads, + num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads, hidden_act="gelu", is_encoder_decoder=False, is_decoder=True, attn_implementation=hf_attention, - #gradient_checkpointing=self.activation_checkpointing, + #gradient_checkpointing=self.gradient_checkpointing, )) else: self.model = MixtralModel(MixtralConfig( @@ -554,7 +554,7 @@ class Base(nn.Module): num_hidden_layers=n_layers, num_attention_heads=n_heads, attention_dropout=p_dropout if training else 0.0, - num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads, + num_key_value_heads=self.hyper_config.kv_heads if self.hyper_config.kv_heads > 0 else n_heads, sliding_window=75 * 12, # 12 second context window output_router_logits=training, hidden_act="gelu", @@ -563,10 +563,10 @@ class Base(nn.Module): num_local_experts=n_experts, num_experts_per_tok=min(2, n_experts), attn_implementation=hf_attention, - #gradient_checkpointing=self.activation_checkpointing, + #gradient_checkpointing=self.gradient_checkpointing, )) - if self.activation_checkpointing and not self.model.gradient_checkpointing: + if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False )) @@ -589,7 +589,7 @@ class Base(nn.Module): is_encoder_decoder=False, is_decoder=True, attn_implementation=hf_attention, - #gradient_checkpointing=self.activation_checkpointing, + #gradient_checkpointing=self.gradient_checkpointing, )) else: self.model = MixtralModel(MixtralConfig( @@ -609,10 +609,10 @@ class Base(nn.Module): num_local_experts=n_experts, num_experts_per_tok=min(2, n_experts), attn_implementation=hf_attention, - #gradient_checkpointing=self.activation_checkpointing, + #gradient_checkpointing=self.gradient_checkpointing, )) - if self.activation_checkpointing and not self.model.gradient_checkpointing: + if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False )) @@ -628,7 +628,7 @@ class Base(nn.Module): decoder_ffn_embed_dim=d_model * 4, decoder_layers=n_layers, dropout=p_dropout if training else 0.0, - checkpoint_activations=self.activation_checkpointing, + checkpoint_activations=self.gradient_checkpointing, activation_fn="gelu", use_layernorm=self.version < 3, use_biases=self.version < 3, @@ -660,7 +660,7 @@ class Base(nn.Module): decoder_ffn_embed_dim=d_model * 4, decoder_layers=n_layers, dropout=p_dropout if training else 0.0, - checkpoint_activations=self.activation_checkpointing, + checkpoint_activations=self.gradient_checkpointing, activation_fn="gelu", use_glu=False, # self.version >= 3, @@ -673,7 +673,7 @@ class Base(nn.Module): self.model = RetNetDecoder_HF(RetNetConfig_HF(**kwargs)) - if self.activation_checkpointing and not self.model.gradient_checkpointing: + if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False )) @@ -684,13 +684,13 @@ class Base(nn.Module): depth=n_layers, heads=n_heads, ff_mult=4, - activation_checkpointing=self.activation_checkpointing, + gradient_checkpointing=self.gradient_checkpointing, ) else: raise RuntimeError(f'Unknown arch specified: {self.arch_type}') - if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]: - self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention ) + if self.hyper_config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]: + self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.hyper_config.attention ) self.classifier = nn.Linear(d_model, n_resp_tokens) @@ -877,7 +877,7 @@ class Base(nn.Module): quant_levels: Tensor | None = None, ): # old, "naive" way, no loss factoring - if not self.config.loss_factors: + if not self.hyper_config.loss_factors: target_list = [] for batch in inputs: target = [] diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index 606a839..2875ed0 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -1,9 +1,20 @@ +""" +This is an experiment to: +* entertain a thought to try and abide by HF's transformers API (to benefit from caching better) +* conform to a single embedding (instead of a bunch of them) by folding/unfolding inputs +* stop trying to make a mixed AR+NAR model work since it seems lobotomized if I keep trying to enforce both recurrent and parallel inferencing (despite a penalty cost) + + I will not cave and go with codebook patterns, not yet. +""" + from ..config import cfg +from ..data import fold_inputs, unfold_outputs + import torch from torch.nn.utils.rnn import pad_sequence from torch import Tensor from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint import random import math @@ -21,144 +32,40 @@ except Exception as e: pass try: - from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig, MixerModel as MambaMixelModel, layer_norm_fn as MambaLayerNormFn, RMSNorm as MambaRMSNorm + + def MambaMixelModel_forward(self, input_ids, inference_params=None, **mixer_kwargs): + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + if self.gradient_checkpointing and hidden_states.requires_grad: + hidden_states, residual = checkpoint( layer, hidden_states, residual, inference_params=inference_params, use_reentrant=False ) + else: + hidden_states, residual = layer( hidden_states, residual, inference_params=inference_params ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + hidden_states = MambaLayerNormFn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm_f, MambaRMSNorm) + ) + return hidden_states + + MambaMixelModel.forward = MambaMixelModel_forward + AVAILABLE_ARCHES.append("mamba") except Exception as e: print("Error importing `mamba` arch:", e) pass -def _create_mask(l, device): - seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) - stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) - return (seq < stop).float() # (b t) - -def list_to_tensor(x_list: list[Tensor]): - l = list(map(len, x_list)) - x = pad_sequence(x_list).t() - - m = _create_mask(l, x_list[0].device) - m = m.to(x) - return x, m - -# fold into a typical LLM sequence (one embedding rather than split embeddings) -def fold( - text_list = [], - proms_list = [], - resp_list = [], - - ignore_index = None, - - sep = 3, - stop = 3, - - text_tokens = 256, - audio_tokens = 1024, - audio_rvq_levels = cfg.model.prom_levels -): - - device = text_list[0].device - batch_size = len(text_list) - input_ids = [ [] for _ in range(batch_size) ] - - offset = 0 - - sep = torch.Tensor([ sep ]) - stop = torch.Tensor([ stop ]) - - for i, text in enumerate(text_list): - seq = text.to("cpu", dtype=torch.int64) - input_ids[i].append( seq ) - input_ids[i].append( sep ) - - offset = text_tokens - for i, prom in enumerate(proms_list): - if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) - else: - seq = prom.flatten().to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) - - input_ids[i].append( seq ) - input_ids[i].append( sep ) - - offset = text_tokens + (audio_tokens * audio_rvq_levels) - for i, resp in enumerate(resp_list): - seq = resp.flatten().to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) - input_ids[i].append( seq ) - input_ids[i].append( stop ) - - for i, batch in enumerate(input_ids): - input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) - - return list_to_tensor(input_ids) - -# unfold from one unified token ID space to separate token spaces -def unfold( - input_ids, - - sep = 3, - stop = 3, - - text_tokens = 256, - audio_tokens = 1024, - audio_rvq_levels = cfg.model.prom_levels -): - device = input_ids.device - batch_size = input_ids.shape[0] - - text_list = [ [] for _ in range(batch_size) ] - prom_list = [ [] for _ in range(batch_size) ] - resp_list = [ [] for _ in range(batch_size) ] - - for i, batch in enumerate( input_ids ): - for idx, token in enumerate( batch ): - id = token.item() - if id == sep or id == stop: - continue - - if 0 <= id and id < text_tokens: - text_list[i].append( id ) - elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels): - prom_list[i].append( (id - text_tokens) % audio_tokens ) - elif text_tokens + (audio_tokens * audio_rvq_levels) <= id: - resp_list[i].append( (id - text_tokens) % audio_tokens ) - - prom_len = len(prom_list[i]) - if prom_len % audio_rvq_levels == 0 and False: - prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() - else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( prom_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( prom_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64) - - - resp_len = len(resp_list[i]) - if len(resp_list[i]) % audio_rvq_levels == 0 and False: - resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() - else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( resp_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( resp_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64) - - text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64) - - return dict( - text_list=text_list, - prom_list=prom_list, - resp_list=resp_list - ) - SELECTED_ARCH = cfg.model.arch_type if SELECTED_ARCH not in AVAILABLE_ARCHES: @@ -179,9 +86,12 @@ class Model(LlmArchClass): n_heads=16, p_dropout=0.1, - attention_backend=None, - activation_checkpointing=True, + config = None, ): + self.hyper_config = config + + hf_attention = config.attention if config is not None else None + gradient_checkpointing = config.gradient_checkpointing if config is not None else True if SELECTED_ARCH == "llama": super().__init__(config=LlamaConfig( @@ -197,10 +107,10 @@ class Model(LlmArchClass): hidden_act="gelu", is_encoder_decoder=False, is_decoder=True, - attn_implementation=attention_backend, + attn_implementation=hf_attention, )) - if activation_checkpointing: + if gradient_checkpointing: self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False )) @@ -209,9 +119,11 @@ class Model(LlmArchClass): vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1, d_model=d_model, n_layer=n_layers*2, - #ssm_cfg={"layer": "Mamba2"}, + #ssm_cfg={"layer": "Mamba2"}, # will ALWAYS nan )) + self.backbone.gradient_checkpointing = gradient_checkpointing + def forward( self, @@ -293,9 +205,9 @@ def example_usage(): proms_list = proms_list[:1] resps_list = resps_list[:1] - input_ids, attention_mask = fold(text_list, proms_list, resps_list) - target_ids, target_attention_mask = fold(text_list, proms_list, resps_list, ignore_index=-100) - prefix_input_ids, prefix_attention_mask = fold(text_list, proms_list) + input_ids, attention_mask = fold_inputs(text_list, proms_list, resps_list) + target_ids, target_attention_mask = fold_inputs(text_list, proms_list, resps_list, ignore_index=-100) + prefix_input_ids, prefix_attention_mask = fold_inputs(text_list, proms_list) kwargs = {} model = Model(**kwargs).to(device) @@ -373,7 +285,7 @@ def example_usage(): else: output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False) - unfolded = unfold( output ) + unfolded = unfold_outputs( output ) for i, batch in enumerate(unfolded["resp_list"]): _ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device) diff --git a/vall_e/models/transformer.py b/vall_e/models/transformer.py index 4d749c1..a2093f4 100755 --- a/vall_e/models/transformer.py +++ b/vall_e/models/transformer.py @@ -15,7 +15,29 @@ from torch import Tensor, einsum, nn from torch.utils.checkpoint import checkpoint from ..utils import wrapper as ml -from .adaln import AdaLN + +class AdaLN(nn.Module): + def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2): + super().__init__() + self.eps = eps + self.emb = nn.Embedding(n_levels, d_model * 2) + self.k = k + self.c = c + nn.init.zeros_(self.emb.weight) + + def forward(self, x, l): + h = F.layer_norm(x, x.shape[-1:], eps=self.eps) + + # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135) + # performed worse than vanilla LayerNorm. + # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN. + # Did they use AdaNorm inside AdaLN? (as follows) + h = self.c * (1 - (self.k * h).detach()) * h + + logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1) + y = logγ.exp() * h + β + + return y class SinusoidalEmbedding(nn.Module): def __init__(self, d_model): diff --git a/vall_e/train.py b/vall_e/train.py index a9339e8..0dec17c 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -5,6 +5,7 @@ from .data import create_train_val_dataloader from .emb import qnt from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc +from .data import fold_inputs, unfold_outputs import auraloss import json @@ -25,12 +26,29 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu") def train_feeder(engine, batch): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): - engine( - text_list=batch["text"], - proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level - resps_list=batch["resps"], - lang_list=batch["lang"], - ) + if engine.hyper_config.experimental: + input_ids, attention_mask = fold_inputs( + text_list=batch["text"], + prom_list=batch["proms"], + resp_list=batch["resps"], + ) + target_ids, target_attention_mask = fold_inputs( + text_list=batch["text"], + prom_list=batch["proms"], + resp_list=batch["resps"], + ignore_index=-100 + ) + engine( + input_ids=input_ids, + labels=target_ids + ) + else: + engine( + text_list=batch["text"], + proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level + resps_list=batch["resps"], + lang_list=batch["lang"], + ) losses = engine.gather_attribute("loss") stat = engine.gather_attribute("stats") @@ -48,22 +66,6 @@ def train_feeder(engine, batch): @torch.inference_mode() def run_eval(engines, eval_name, dl): - AR = None - NAR = None - AR_NAR = None - - names = [] - for name, engine in engines.items(): - if name[:6] == "ar+nar": - AR_NAR = engine - elif name[:2] == "ar": - AR = engine - elif name[:3] == "nar": - NAR = engine - else: - continue - names.append(name) - stats = defaultdict(list) stats['loss'] = [] @@ -101,44 +103,22 @@ def run_eval(engines, eval_name, dl): batch: dict = to_device(next(iter(dl)), cfg.device) processed += len(batch["text"]) - # if we're training both models, provide output for both - if AR is not None and NAR is not None: - name = "+".join(names) + for name in engines: + engine = engines[name] - resps_list = AR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) - resps_list = [ r.unsqueeze(-1) for r in resps_list ] - resps_list = NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) + if engine.hyper_config.experimental: + input_ids, attention_mask = fold_inputs( + text_list=batch["text"], + proms_list=batch["proms"], + ) + output = engine.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False) + resps_list = unfold_outputs( output )["resp_list"] + else: + resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) + resps_list = [ r.unsqueeze(-1) for r in resps_list ] + resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) process( name, batch, resps_list ) - else: - for name in engines: - model = engines[name] - - if name.startswith("ar+nar"): - resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) - resps_list = [ r.unsqueeze(-1) for r in resps_list ] - resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature) - elif name.startswith("ar"): - resps_list = model( - text_list=batch["text"], - proms_list=batch["proms"], - lang_list=batch["lang"], - max_steps=cfg.evaluation.steps, - sampling_temperature=cfg.evaluation.ar_temperature, - ) - resps_list = [r.unsqueeze(-1) for r in resps_list] - elif name.startswith("nar"): - resps_list = model( - text_list=batch["text"], - proms_list=batch["proms"], - lang_list=batch["lang"], - resps_list=[r[..., 0].unsqueeze(-1) for r in batch["resps"]], - sampling_temperature=cfg.evaluation.nar_temperature, - ) - else: - raise NotImplementedError(name) - - process( name, batch, resps_list ) stats = {k: sum(v) / len(v) for k, v in stats.items()}