From 5cd71ef238f351fd51f4b6be2d97f7337f7dc72a Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 6 Mar 2025 14:48:14 -0600 Subject: [PATCH] QoL so I can stop having to manually inject different configs --- docs/train.md | 4 +- vall_e/config.py | 17 +++-- vall_e/models/ar_nar_v2.py | 2 +- vall_e/models/base_v2.py | 152 ++++++++++++++++++++----------------- 4 files changed, 97 insertions(+), 78 deletions(-) diff --git a/docs/train.md b/docs/train.md index a841ecb..23220b4 100644 --- a/docs/train.md +++ b/docs/train.md @@ -25,6 +25,8 @@ Training is (not-so-obviously) not dependent on: * for the old (`base.py`) implementation, further experimentation is required, but from what I remember the smaller models don't have speech emerge as fast, while the larger size models don't seem to benefit much. * for the new (`base_v2.py`) implementation, it seems that model size doesn't affect quality at all, at least in the primary phase of getting it to speak. * the "training progression" (how the loss/accuracy/grad norm curves look) is about the same between the "normal" (1024 dim, 12 layers, 12 heads) size and the "half" (512 dim, 12 layers, 8 heads) size, and presumably for the "double" size (1536 dim, 24 layers, 24 heads). + * the "half" size might actually be too small for it to have enough "capacity" to attend to all the speakers I'm training against. + * per E2/F5's paper, a size of 1024 dim, 4x FFN, 16 heads, 24 layers might be preferable? * it *probably* is necessary to have a larger model to better adhere to the reference clip, but experimentation is not at the point yet to verify this. * the audio codec, to an extent * for the old (`base.py`) implementation, EnCodec only seems to work well for it, as DAC might require some hacks or patience for the higher RVQ levels to train, while `nvidia/audio-codec-44khz` requires an exotic training curriculum, assumedly. @@ -40,7 +42,6 @@ A training paradigm that seems to work for me is to: * this also benefits from the model training on a variety of durations to avoid it overfitting to the last duration set trained against * optionally, you can sample based on speaker instead to balance out the speakers trained against, but this isn't all that necessary - Training under `float16` (+AMP) should be fairly simple, but it's practically required to use the `deepspeed` backend. * This is because `deepspeed` will automatically wrap the optimizer to handle training under `float16` and does some extra magic for stability. The `local` backend does do loss scaling, but not the extra steps. * Training under `bfloat16` does not have to worry about this, but I personally feel `bfloat16` training sessions lack a specific training trait that `float16` ones do have. @@ -72,7 +73,6 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster, * `APOLLO` needs more testing, but seemed adequate in cursory tests * `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True` * I honestly don't think it gives good enough results from cursory tests for this application -* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
## Try Me diff --git a/vall_e/config.py b/vall_e/config.py index 4066677..c64880a 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -281,6 +281,12 @@ class ModelExperimentalSettings: # logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298 per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder + audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training + # "auto" will pick the best strategy for the codec + # "decreasing" will do the RVQ strat (prioritize lower levels) + # "normal" will do the FSQ strat (prioritize midrange) + # "equal" or "none" will apply no leveling + # a list of floats will set the factors manually # these should technically be hyperparameters # performs token dropout to compensate for errors @@ -561,8 +567,9 @@ class DeepSpeed: optimizer: bool = True # use DeepSpeed optimizer wrapper amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently) - loss_scale_window: int = 100 - min_loss_scale: float = 8192.0 + loss_scale_window: int = 1000 + min_loss_scale: float = 32768.0 + loss_scale: float = 0.0 config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config @@ -614,9 +621,9 @@ class DeepSpeed: "fp16": { "enabled": cfg.trainer.weight_dtype.lower() == "float16", "auto_cast": True, # ??? - "loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps - "min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy - "loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0, + "loss_scale_window": self.loss_scale_window, + "min_loss_scale": self.min_loss_scale, + "loss_scale": self.loss_scale if cfg.trainer.scale_loss else 1.0, # use the defined loss scale (defaults to 0, which is dynamic) if requested, or 1.0 (none) if not }, "bf16": { "enabled": cfg.trainer.weight_dtype.lower() == "bfloat16", diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py index 7b8915b..0cfc0ed 100644 --- a/vall_e/models/ar_nar_v2.py +++ b/vall_e/models/ar_nar_v2.py @@ -1036,7 +1036,7 @@ def example_usage(): if task == "stt": prom = [ task ] else: - task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len" + task = "tts" # if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len" texts.append( text ) proms.append( prom ) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index 76b55ca..26486e3 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -300,7 +300,7 @@ class Base_V2(nn.Module): resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True predict_causally = config.experimental.predict_causally if config is not None else False monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False - audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ + audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto" logit_normalization = config.experimental.logit_normalization if config is not None else 0 per_level_normalization = config.experimental.per_level_normalization if config is not None else True @@ -309,6 +309,27 @@ class Base_V2(nn.Module): n_langs = config.langs if config is not None else 2 n_tones = config.tones if config is not None else 1 + if
audio_level_loss_factors == "auto": + audio_level_loss_factors = "normal" if n_audio_tokens == 1000 else "decreasing" + + if audio_level_loss_factors == "decreasing": + audio_level_loss_factors = [1.0 / (i + 1) for i in range(n_resp_levels)] + elif audio_level_loss_factors == "normal": + if n_resp_levels == 8: + audio_level_loss_factors = [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5] + else: + center = n_resp_levels // 2 + audio_level_loss_factors = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)] + + # to-do: proper cirriculum + # prioritizes midrange, maybe good for epoch 0? + # [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5] + + # deprioritizes midrange, good for epoch 1? + # [0.875, 0.75, 0.625, 0.5, 0.5, 0.625, 0.75, 0.875] + elif audio_level_loss_factors == "equal": + audio_level_loss_factors = [1.0 for _ in range(n_resp_levels)] + if attention_backend == "auto": attention_backend = "sdpa" @@ -320,18 +341,6 @@ class Base_V2(nn.Module): if attention_backend not in AVAILABLE_ATTENTIONS: raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}") - # to-do: deduce nemo better-er - if n_audio_tokens == 1000: - # assume midrage contains important details - center = n_resp_levels // 2 - audio_level_weights = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)] - # to-do: proper cirriculum - # prioritizes midrange, maybe good for epoch 0? - # [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5] - - # deprioritizes midrange, good for epoch 1? - # [0.875, 0.75, 0.625, 0.5, 0.5, 0.625, 0.75, 0.875] - self.training = training self.teaching = False self.config = config @@ -380,7 +389,7 @@ class Base_V2(nn.Module): self.masking_ratio = masking_ratio self.ignore_inputs_for_loss = ignore_inputs_for_loss self.noncausal_masks = noncausal_masks - self.audio_level_weights = audio_level_weights + self.audio_level_loss_factors = audio_level_loss_factors self.logit_normalization = logit_normalization self.sep = nn.Parameter(torch.randn(d_model)) @@ -391,6 +400,7 @@ class Base_V2(nn.Module): self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None self.len_emb = ml.Embedding(11, d_model) + # to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something self.audio_emb = None self.proms_emb = None @@ -867,7 +877,7 @@ class Base_V2(nn.Module): device = logits[0].device batch_size = len(logits) classifier_levels = self.get_input( inputs, "classifier_level" ) - level_weights = self.audio_level_weights + level_loss_factor = self.audio_level_loss_factors # handles tasks where the prompt has task tokens injected in the middle def prompt_input_to_token( input, quant_level ): @@ -876,6 +886,7 @@ class Base_V2(nn.Module): return input + k_lo, k_hi = 1, 20 def _calc_loss( logit, sequence, causal = True, level = None ): # filter tokens that exceed the vocab size sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence ) @@ -906,37 +917,38 @@ class Base_V2(nn.Module): sequence = sequence.reshape(-1) nll = None - acc_k1 = None + acc_k_lo = None if compute_hard_loss: reduction = 'mean' if not batched else 'none' - weight = level_weights[level] if level is not None and not batched else 1 - + weight = level_loss_factor[level] if level is not None and not batched else 1 nll = F.cross_entropy( logit, sequence, 
ignore_index=self.ignore_index, reduction=reduction ) * weight # manually weigh each level if batched: - nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device) + nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device) if compute_acc: - accuracy_metric = MulticlassAccuracy( - logit.shape[-1], - top_k = 1, - average="micro", - multidim_average="global", - ignore_index = -100 - ).to(logit.device) - acc_k1 = accuracy_metric( logit, sequence ) + if logit.shape[0] >= k_lo: + accuracy_metric = MulticlassAccuracy( + logit.shape[-1], + top_k = 1, + average="micro", + multidim_average="global", + ignore_index = -100 + ).to(logit.device) + acc_k_lo = accuracy_metric( logit, sequence ) - accuracy_metric = MulticlassAccuracy( - logit.shape[-1], - top_k = min(logit.shape[0], 80), - average="micro", - multidim_average="global", - ignore_index = -100 - ).to(logit.device) - acc_k80 = accuracy_metric( logit, sequence ) + if logit.shape[0] >= k_hi: + accuracy_metric = MulticlassAccuracy( + logit.shape[-1], + top_k = 20, + average="micro", + multidim_average="global", + ignore_index = -100 + ).to(logit.device) + acc_k_hi = accuracy_metric( logit, sequence ) - return nll, acc_k1, acc_k80 + return nll, acc_k_lo, acc_k_hi for batch_index, batch in enumerate(inputs): quant_level = quant_levels[batch_index] @@ -1022,7 +1034,7 @@ class Base_V2(nn.Module): continue if logits[batch_index].dim() < 3: - nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][start:end], token.long(), causal ) elif not self.resp_parallel_training: # cringe way to deduce "requested" level level = quant_level @@ -1035,31 +1047,31 @@ class Base_V2(nn.Module): name = f'{name}[{level}]' sequence = token if token.dim() <= 1 else token[:, level] - nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level ) else: sequence = token.t() - nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal ) if nll is not None: nll = nll.mean() loss_key = f'{name}.nll' - acc_k1_key = f'{name}.acc[k=1]' - acc_k80_key = f'{name}.acc[k=80]' + acc_k_lo_key = f'{name}.acc[k={k_lo}]' + acc_k_hi_key = f'{name}.acc[k={k_hi}]' if nll is not None: if loss_key not in loss: loss[loss_key] = [] loss[loss_key].append( nll * loss_factor ) - if acc_k1 is not None: - if acc_k1_key not in stats: - stats[acc_k1_key] = [] - stats[acc_k1_key].append( acc_k1 ) + if acc_k_lo is not None: + if acc_k_lo_key not in stats: + stats[acc_k_lo_key] = [] + stats[acc_k_lo_key].append( acc_k_lo ) - if acc_k80 is not None: - if acc_k80_key not in stats: - stats[acc_k80_key] = [] - stats[acc_k80_key].append( acc_k80 ) + if acc_k_hi is not None: + if acc_k_hi_key not in stats: + stats[acc_k_hi_key] = [] + stats[acc_k_hi_key].append( acc_k_hi ) # add to list else: target.append( token ) @@ -1069,7 +1081,7 @@ class Base_V2(nn.Module): if not self.config.loss_factors: if logits[batch_index].dim() < 3: sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) - nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index], 
sequence, causal ) elif not self.resp_parallel_training: # cringe way to deduce "requested" level level = 0 @@ -1080,45 +1092,45 @@ class Base_V2(nn.Module): sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ] sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) ) - nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level], sequence.long(), causal, level ) else: nlls = [] - acc_k1s = [] - acc_k80s = [] + acc_k_los = [] + acc_k_his = [] for level, logit in enumerate( logits[batch_index] ): sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ] sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) ) - nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level ) + nll, acc_k_lo, acc_k_hi = _calc_loss( logit, sequence, causal, level ) if nll: nlls.append( nll ) - if acc_k1: - acc_k1s.append( acc_k1 ) - if acc_k80: - acc_k80s.append( acc_k80 ) + if acc_k_lo: + acc_k_los.append( acc_k_lo ) + if acc_k_hi: + acc_k_his.append( acc_k_hi ) if nlls: nll = sum(nlls) / len(nlls) - if acc_k1s: - acc_k1 = sum(acc_k1s) / len(acc_k1s) - if acc_k80s: - acc_k80 = sum(acc_k80s) / len(acc_k80s) + if acc_k_los: + acc_k_lo = sum(acc_k_los) / len(acc_k_los) + if acc_k_his: + acc_k_hi = sum(acc_k_his) / len(acc_k_his) if nll is not None: if 'nll' not in loss: loss['nll'] = [] loss["nll"].append( nll ) - if acc_k1 is not None: - if 'acc[k=1]' not in stats: - stats['acc[k=1]'] = [] - stats["acc[k=1]"].append( acc_k1 ) + if acc_k_lo is not None: + if f'acc[k={k_lo}]' not in stats: + stats[f'acc[k={k_lo}]'] = [] + stats[f"acc[k={k_lo}]"].append( acc_k_lo ) - if acc_k80 is not None: - if 'acc[k=80]' not in stats: - stats['acc[k=80]'] = [] - stats["acc[k=80]"].append( acc_k80 ) + if acc_k_hi is not None: + if f'acc[k={k_hi}]' not in stats: + stats[f'acc[k={k_hi}]'] = [] + stats[f"acc[k={k_hi}]"].append( acc_k_hi ) # average loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
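A quick reference for what the new `audio_level_loss_factors` strategies resolve to; this is a standalone sketch that mirrors the selection logic added to `base_v2.py` above, with the function name and defaults being illustrative only:

```python
# Illustrative sketch only -- mirrors the strategy selection added in base_v2.py;
# nothing here is imported from vall_e.
def resolve_level_loss_factors(strategy, n_resp_levels: int = 8, n_audio_tokens: int = 1000) -> list[float]:
    if isinstance(strategy, list):
        return strategy  # explicit per-level factors pass through untouched
    if strategy == "auto":
        # the patch's heuristic: a 1000-entry codebook implies the FSQ-style codec,
        # anything else is treated as RVQ-style (EnCodec/DAC)
        strategy = "normal" if n_audio_tokens == 1000 else "decreasing"
    if strategy == "decreasing":
        # RVQ strat: lower levels carry the coarse structure, so weigh them highest
        return [1.0 / (i + 1) for i in range(n_resp_levels)]
    if strategy == "normal":
        # FSQ strat: prioritize the midrange levels
        if n_resp_levels == 8:
            return [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
        center = n_resp_levels // 2
        return [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)]
    # "equal" / "none": no per-level weighting
    return [1.0 for _ in range(n_resp_levels)]

print(resolve_level_loss_factors("decreasing"))  # [1.0, 0.5, 0.333..., 0.25, 0.2, 0.166..., 0.142..., 0.125]
print(resolve_level_loss_factors("normal"))      # [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
```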
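For how those factors actually enter the loss in the batched path of `_calc_loss`, a small self-contained sketch (the shapes and tensor layout are assumptions for illustration, not lifted from the repo):

```python
import torch
import torch.nn.functional as F

n_resp_levels, seq_len, n_audio_tokens = 8, 50, 1000
level_loss_factor = [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]  # the "normal" strategy

# assumed layout: all codebook levels stacked along the batch dimension, as in the batched branch
logit = torch.randn(n_resp_levels * seq_len, n_audio_tokens)
sequence = torch.randint(0, n_audio_tokens, (n_resp_levels * seq_len,))

# per-token NLL with no reduction, then a per-level mean, then weigh each level by its factor
nll = F.cross_entropy(logit, sequence, ignore_index=-100, reduction='none')
nll = nll.view(n_resp_levels, -1).mean(dim=-1) * torch.tensor(level_loss_factor)
print(nll.shape)  # torch.Size([8]) -- one weighted loss term per codebook level
```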
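The `acc[k=1]` / `acc[k=20]` bookkeeping can also be exercised in isolation; a minimal sketch of the same length-gated torchmetrics usage, where `topk_accuracies` is a hypothetical helper rather than anything in the repo:

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

def topk_accuracies(logit: torch.Tensor, sequence: torch.Tensor, k_lo: int = 1, k_hi: int = 20) -> dict:
    accs = {}
    for k in (k_lo, k_hi):
        # skip the metric entirely for sequences shorter than k, as the patch now does
        if logit.shape[0] < k:
            continue
        metric = MulticlassAccuracy(
            logit.shape[-1],
            top_k=k,
            average="micro",
            multidim_average="global",
            ignore_index=-100,
        ).to(logit.device)
        accs[f"acc[k={k}]"] = metric(logit, sequence)
    return accs

# e.g. a 50-token sequence over a 1000-entry codebook reports both acc[k=1] and acc[k=20]
print(topk_accuracies(torch.randn(50, 1000), torch.randint(0, 1000, (50,))))
```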