From 9cfbf94b1c8a591481a355ffbc1e7b1f42eff5e0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 14 Mar 2025 20:30:48 -0500
Subject: [PATCH] config-ify the len_loss_factor

---
 docs/models_v2.md        | 1 +
 vall_e/config.py         | 1 +
 vall_e/models/base_v2.py | 4 +++-
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/docs/models_v2.md b/docs/models_v2.md
index 8f28297..9a9ad10 100644
--- a/docs/models_v2.md
+++ b/docs/models_v2.md
@@ -5,6 +5,7 @@ This section aims to document the `_v2` class of models. Documentation here migh
 
 Unlike the original, this implementation strives to operate on *all* codebooks at once with a full 44KHz bandwidth, rather than requiring the model to operate on one codebook level at a time at 24KHz audio.
 
 This model *might* not scale up well, as the `nemo-smaller-44khz-llama-8` brand seems to perform at a similar quality to `nemo-larger-44khz-llama-8`. While the latter had speech emerge much quicker than the former, both seem to have a problem with consistently working on various speakers unlike the previous series of models.
+* The current issue seems to be that it poorly follows the prompted speaker, which, if I remember right, took quite a few epochs to resolve in the base `ar+nar-len-llama-8` model.
 
 ## Audio Codecs
diff --git a/vall_e/config.py b/vall_e/config.py
index a622aa0..4ded6b4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -296,6 +296,7 @@ class ModelExperimentalSettings:
 	# * NAR-demask would semi-doubly train for AR
 	# * the model wouldn't also need to learn when to predict the token in place
 	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
+	len_loss_factor: float = 0.00001 # loss factor for the len calculation; kept very small because a large len loss mucks up loss scaling under float16
 
 	# logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 770907c..fc78d0f 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -271,6 +271,7 @@ class Base_V2(nn.Module):
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
+		len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
@@ -362,6 +363,7 @@
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
 		self.audio_level_loss_factors = audio_level_loss_factors
+		self.len_loss_factor = len_loss_factor
 		self.logit_normalization = False # this actually kills the model's demasking capabilities
 		self.use_segmented_attention_mask = use_segmented_attention_mask
 
@@ -1047,7 +1049,7 @@
 
 			# check if len logits are provided
 			if logits_aux is not None:
-				len_factor = 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
+				len_factor = self.len_loss_factor # formerly a hardcoded 0.001 (kept really small because mse_loss causes wildly bigly losses)
 				aux_loss_logit = torch.cat( logits_aux )
 				#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
 				#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
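
For reference, a minimal sketch of the auxiliary duration ("len") loss that this factor scales, assuming the mse_loss path hinted at in the hunk above. `logits_aux` and `resp_durations` mirror names from the diff; the `len_loss` helper, its signature, and the float cast of the targets are illustrative assumptions, not the repository's actual code:

    import torch
    import torch.nn.functional as F

    def len_loss(logits_aux, resp_durations, len_loss_factor=0.00001):
        # Concatenate the per-sample duration predictions into one tensor.
        aux_loss_logit = torch.cat(logits_aux)
        # Targets as floats, since mse_loss compares real values.
        aux_loss_target = torch.tensor(
            resp_durations,
            device=aux_loss_logit.device,
            dtype=aux_loss_logit.dtype,
        )
        # mse_loss on raw durations produces very large values; the tiny
        # factor keeps the len loss from dominating the total loss and
        # from blowing up the float16 loss scaler.
        return F.mse_loss(aux_loss_logit, aux_loss_target) * len_loss_factor

With the change above, the factor would be read from the model's `experimental.len_loss_factor` setting (defaulting to 0.00001) instead of the hardcoded 0.001.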