config-ify the len_loss_factor

This commit is contained in:
mrq 2025-03-14 20:30:48 -05:00
parent ca8cc15271
commit 9cfbf94b1c
3 changed files with 5 additions and 1 deletions

View File

@ -5,6 +5,7 @@ This section aims to document the `_v2` class of models. Documentation here migh
Unlike the original, this implementation strives to operate on *all* codebooks at once with a full 44KHz bandwidth, rather than requiring the model to operate on one codebook level at a time at 24KHz audio.
This model *might* not scale up well, as the `nemo-smaller-44khz-llama-8` brand seems to perform at a similar quality to `nemo-larger-44khz-llama-8`. While the latter had speech emerge much quicker than the former, both seem to have a problem with consistently working on various speakers, unlike the previous series of models.
* The current issue seems to be that it poorly follows the prompted speaker, which, if I remember right, required quite a few epochs to resolve in the base `ar+nar-len-llama-8` model.
## Audio Codecs

View File

@ -296,6 +296,7 @@ class ModelExperimentalSettings:
# * NAR-demask would semi-doubly train for AR
# * the model wouldn't also need to learn when to predict the token in place
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
#
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298

View File

@ -271,6 +271,7 @@ class Base_V2(nn.Module):
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
@ -362,6 +363,7 @@ class Base_V2(nn.Module):
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.audio_level_loss_factors = audio_level_loss_factors
self.len_loss_factor = len_loss_factor
self.logit_normalization = False # this actually kills the model's demasking capabilities
self.use_segmented_attention_mask = use_segmented_attention_mask
@ -1047,7 +1049,7 @@ class Base_V2(nn.Module):
# check if len logits are provided
if logits_aux is not None:
len_factor = 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
aux_loss_logit = torch.cat( logits_aux )
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor