From a65c8144f4b41725b6dc742e7ba3e79be72088ac Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 13 Feb 2025 18:38:40 -0600
Subject: [PATCH] with the amount of tweaks I keep making I could have probably had the nvidia/audio-codec-44khz model realized already......

---
 vall_e/config.py           |  7 +++
 vall_e/engines/__init__.py |  5 ++-
 vall_e/engines/base.py     | 17 +++++++
 vall_e/models/ar_nar.py    |  8 ++--
 vall_e/models/base.py      | 91 ++++++++++++++++++++++++++++----------
 vall_e/utils/wrapper.py    |  2 +-
 6 files changed, 100 insertions(+), 30 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index becd445..2359244 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -414,6 +414,13 @@ class Model:
 			return 16
 		return 12
 
+	@property
+	def ffn(self):
+		if isinstance(self.size, dict) and "ffn" in self.size:
+			return self.size['ffn']
+
+		return 4
+
 	@property
 	def activation_checkpointing(self):
 		return cfg.trainer.activation_checkpointing
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 3c514c8..8d9c133 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -396,11 +396,12 @@ def load_engines(training=True, **model_kwargs):
 		if cfg.lora is not None:
 			key_name = cfg.lora.full_name
 
-		kwargs['id'] = 'job'
+		salt = "run"
+		kwargs['id'] = f'{key_name}-{salt}'
 		kwargs['resume'] = 'allow'
 		if world_size() > 1:
 			kwargs["group"] = "DDP"
-			kwargs['id'] = f'job-{global_rank()}'
+			kwargs['id'] = f'{key_name}-{salt}-{global_rank()}'
 
 		engine.wandb = wandb.init(project=key_name, **kwargs)
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 1da2016..880c6aa 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -37,6 +37,7 @@ import time
 import torch
 import torch.distributed
 import os
+import re
 
 from torch import Tensor
 from torch.distributed import all_reduce
@@ -597,6 +598,22 @@ class Engines(dict[str, Engine]):
 			if engine.wandb is not None:
 				engine.wandb.log(model_stats, step=engine.global_step)
 
+			filtered_keys = [ k for k in model_stats.keys() if "[" in k ]
+			filtered_values = {}
+			for k in filtered_keys:
+				v = model_stats[k]
+				del model_stats[k]
+
+				nk = re.sub(r"\[\d+\]", "", k)
+
+				if nk not in filtered_values:
+					filtered_values[nk] = []
+
+				filtered_values[nk].append( v )
+
+			for k, v in filtered_values.items():
+				model_stats[k] = sum(v) / len(v)
+
 			model_stats = model_stats | dict(
 				lr=engine.get_lr()[0],
 				elapsed_time=elapsed_time,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 8dd6a04..326f96f 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -171,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 			# only apply stop token for RVQ level 0
-			if quant_level <= 0 and timesteps[i] is None:
+			if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
@@ -1232,7 +1232,7 @@ def example_usage():
 		'n_text_tokens': cfg.model.text_tokens,
 		'n_audio_tokens': cfg.model.audio_tokens,
 
-		'd_model': 1024, # 256, # 1024, # 1536
+		'd_model': 1536, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
 		'n_experts': 1 if not cfg.model else cfg.model.experts,
@@ -1254,7 +1254,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 500 // batch_size
+	steps = 250 // batch_size
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1444,7 +1444,7 @@ def example_usage():
 	"""
 
 	for task in available_tasks:
-		sample("final", task=task)
+		sample("final", task="tts-nar")
 
 	engines.quit()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 15b5b75..b1bb6a0 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -135,7 +135,7 @@ def _interleave_sequence_flatten( input: list[torch.Tensor] ):
 # automagically parses a batch-list and returns it as a list
 """
 
-class Embedding(nn.Embedding):
+class Embedding(ml.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
@@ -192,7 +192,7 @@ class AudioEmbedding_Old(nn.Module):
 		# array of embeddings
 		#   proms are [0, resp_levels]
 		#   resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.tensor([1]) ) for i in range(levels)]) if levels is not None else None
@@ -223,7 +223,7 @@ class AudioEmbedding(nn.Module):
 		# array of embeddings
 		#   proms are [0, resp_levels]
 		#   resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
+		self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
 		#
@@ -350,7 +350,7 @@ class AudioEncoder(nn.Module):
 		token_dim: int,
 	):
 		super().__init__()
-		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
 		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
@@ -538,6 +538,7 @@ class Base(nn.Module):
 		n_raw_text_tokens: int = 8575,
 
 		d_model: int = 512,
+		d_ffn: int = 4,
 		n_heads: int = 8,
 		n_layers: int = 12,
 		p_dropout: float = 0.1,
@@ -735,13 +736,15 @@ class Base(nn.Module):
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
 
+		self.resp_parallel_training = True # governs if all levels are trained in parallel or one per sample like the old way
+		self.monolithic_audio_encoder = False # monolithic sounds bad
 		if self.version >= 7:
 			pd_model = d_model // 4
-			pd_ffn = pd_model * 4
+			pd_ffn = pd_model * d_ffn
 			pd_heads = n_heads // 4
 			pd_layers = 1
 
-			if False:
+			if self.monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
 					n_tokens=n_audio_tokens + 1, # masked token
 					n_levels=self.n_resp_levels,
@@ -763,7 +766,7 @@ class Base(nn.Module):
 				self.n_resp_levels,
 				d_model,
 				dict(
-					vocab_size=n_audio_tokens,
+					vocab_size=n_audio_tokens + 1,
 					hidden_size=pd_model,
 					max_position_embeddings=max_position_embeddings,
 					intermediate_size=pd_ffn,
@@ -821,7 +824,7 @@ class Base(nn.Module):
 				vocab_size=n_vocab,
 				hidden_size=d_model,
 				max_position_embeddings=max_position_embeddings,
-				intermediate_size=d_model*4,
+				intermediate_size=d_model*d_ffn,
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
@@ -1134,7 +1137,7 @@ class Base(nn.Module):
 				inputs[i].append( ( "resp", resps_list[i] ) )
 
 				if self.version >= 7:
-					classifier_level = f"NAR:{quant_level}:{quant_level}"
+					classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
 
 				inputs[i].append( ("classifier_level", classifier_level) )
 		# Audio length prediction task
@@ -1530,7 +1533,7 @@ class Base(nn.Module):
 				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
 
 			# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
-			if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
+			if self.version < 4 or (self.version >= 5 and self.version < 7 and self.config and self.config.experimental.audio_embedding_sums):
 				return torch.full_like(input[..., 0], self.ignore_index)
 
 			if self.version < 7:
@@ -1562,7 +1565,7 @@ class Base(nn.Module):
 			else:
 				accuracy_metric = MulticlassAccuracy(
 					logit.shape[-1],
-					top_k = 10,
+					top_k = min(logit.shape[0], 10),
 					average="micro",
 					multidim_average="global",
 					ignore_index = -100
@@ -1610,6 +1613,9 @@ class Base(nn.Module):
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					# iterate over the list to inject their tokens
 					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+
+					if logits[batch_index].dim() < 3 and token.dim() >= 2:
+						token = token[..., 0]
 				elif name == "resp":
 					# mask found, apply it
 					if self.version < 7:
@@ -1659,9 +1665,24 @@ class Base(nn.Module):
 					if loss_factor == 0.0:
 						continue
+
+					# cringe way to deduce "requested" level
+					level = quant_level
+					for i in range( self.n_resp_levels ):
+						if classifier_level == f'NAR:{i}:{i}':
+							level = i
+							break
 
 					if logits[batch_index].dim() < 3:
 						nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
+
+						if name == "resp":
+							name = f'{name}[{quant_level}]'
+					elif not self.resp_parallel_training:
+						if name == "resp":
+							name = f'{name}[{level}]'
+						sequence = token if token.dim() <= 1 else token[:, level]
+						nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
 					else:
 						nlls = []
 						accs = []
@@ -1670,15 +1691,29 @@ class Base(nn.Module):
 							sequence = token if token.dim() <= 1 else token[:, level]
 							nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
 
-							if nll:
-								nlls.append( nll )
-							if metrics:
-								accs.append( metrics )
+							if name == "resp":
+								if nll is not None:
+									if f'{name}[{level}].nll' not in loss:
+										loss[f'{name}[{level}].nll'] = []
+									loss[f"{name}[{level}].nll"].append( nll * loss_factor )
+
+								if metrics is not None:
+									if f'{name}[{level}].acc' not in stats:
+										stats[f'{name}[{level}].acc'] = []
+									stats[f"{name}[{level}].acc"].append( metrics )
 
-						if nlls:
-							nll = sum(nlls) / len(nlls)
-						if accs:
-							accs = sum(accs) / len(accs)
+								nll = None
+								metrics = None
+							else:
+								if nll:
+									nlls.append( nll )
+								if metrics:
+									accs.append( metrics )
+						else:
+							if nlls:
+								nll = sum(nlls) / len(nlls)
+							if accs:
+								accs = sum(accs) / len(accs)
 
 				if nll is not None:
 					if f'{name}.nll' not in loss:
@@ -1698,6 +1733,16 @@ class Base(nn.Module):
 			if logits[batch_index].dim() < 3:
 				sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 				nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
+			elif not self.resp_parallel_training:
+				# cringe way to deduce "requested" level
+				level = 0
+				for i in range( self.n_resp_levels ):
+					if classifier_level == f'NAR:{i}:{i}':
+						level = i
+						break
+				sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
+				sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
+				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
 			else:
 				nlls = []
 				accs = []
@@ -1779,7 +1824,7 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
-		causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
+		causal_levels = [ "stt", "len", "phn" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
 
 		# right now limit to new versions because I need to retrain the model for noncausal masks...
 		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
@@ -1800,7 +1845,7 @@ class Base(nn.Module):
 		if self.version >= 7:
 			logits = [ logit for logit in logits ]
 
-			audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
+			audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
 
 			decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
 			classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
@@ -2157,7 +2202,7 @@ if __name__ == "__main__":
 			if is_from_pretrained:
 				n_vocab = end - start
 
-				embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
+				embds[k] = ml.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
 				embds[k].weight[:] = model.embed_tokens.weight[start:end, :]
 
 				if classifier_idx >= 0:
@@ -2169,7 +2214,7 @@ if __name__ == "__main__":
 					heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
 			else:
 				embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
-				embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
+				embds[k] = ml.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
 				embds[k].load_state_dict({ "weight": embd_weight })
 
 				if classifier_idx >= 0:
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index cb40e9b..5c86ede 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -32,7 +32,7 @@ if cfg.optimizations.bitsandbytes:
 		Linear = bnb.nn.Linear8bitLt
 
 	if cfg.optimizations.embedding:
-		Embedding = bnb.nn.modules.Embedding
+		Embedding = bnb.nn.StableEmbedding
 		"""
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
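
Note on the Engines change above: after the raw stats are sent to wandb, per-level keys such as "resp[0].nll" are collapsed into a single averaged key for the rest of the stats pipeline. A minimal, self-contained sketch of that collapsing step follows; the helper name and the example keys are illustrative only, not part of the patch:

import re

def collapse_leveled_stats( model_stats: dict ) -> dict:
	# mirrors the post-logging cleanup in Engines: strip "[N]" from keys and average duplicates
	filtered_values = {}
	for k in [ k for k in model_stats.keys() if "[" in k ]:
		v = model_stats.pop( k )
		nk = re.sub( r"\[\d+\]", "", k ) # "resp[3].nll" -> "resp.nll"
		filtered_values.setdefault( nk, [] ).append( v )

	for k, v in filtered_values.items():
		model_stats[k] = sum( v ) / len( v )
	return model_stats

print( collapse_leveled_stats( { "resp[0].nll": 2.0, "resp[1].nll": 4.0, "lr": 1e-4 } ) )
# {'lr': 0.0001, 'resp.nll': 3.0}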
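Similarly, the new "ffn" plumbing (the Model.ffn property in config.py feeding d_ffn in base.py) lets a dict-valued size carry an MLP-width multiplier, falling back to the usual 4x. A small sketch of that lookup, using a hypothetical stand-in class rather than the real cfg.model; note that a dict exposes keys via "in", not via attributes:

class ModelSketch:
	# hypothetical stand-in for the Model dataclass in vall_e/config.py
	def __init__( self, size="full" ):
		self.size = size

	@property
	def ffn( self ):
		# a dict carries no "ffn" attribute, so key membership is the check that matters here
		if isinstance( self.size, dict ) and "ffn" in self.size:
			return self.size["ffn"]
		return 4

d_model = 1024
print( ModelSketch().ffn * d_model )             # 4096: default 4x intermediate size
print( ModelSketch( {"ffn": 8} ).ffn * d_model ) # 8192: overridden via size["ffn"]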