config-ify the len_loss_factor
This commit is contained in:
parent
ca8cc15271
commit
9cfbf94b1c
|
@ -5,6 +5,7 @@ This section aims to document the `_v2` class of models. Documentation here migh
|
|||
Unlike the original, this implementation strives to operate on *all* codebooks at once with a full 44KHz bandwidth, rather than requiring the model to operate on one codebook level at a time at 24KHz audio.
|
||||
|
||||
This model *might* not scale up well, as the `nemo-smaller-44khz-llama-8` model seems to perform at a similar quality to `nemo-larger-44khz-llama-8`. While the latter had speech emerge much quicker than the former, both seem to have a problem with consistently handling various speakers, unlike the previous series of models.
|
||||
* The current issue seems to be that it poorly follows the prompted speaker, which, if I remember right, required quite a few epochs to resolve in the base `ar+nar-len-llama-8` model.
|
||||
|
||||
## Audio Codecs
|
||||
|
||||
|
|
|
@ -296,6 +296,7 @@ class ModelExperimentalSettings:
|
|||
# * NAR-demask would semi-doubly train for AR
|
||||
# * the model wouldn't also need to learn when to predict the token in place
|
||||
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
|
||||
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
|
||||
|
||||
#
|
||||
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
|
||||
|
|
|
@ -271,6 +271,7 @@ class Base_V2(nn.Module):
|
|||
predict_causally = config.experimental.predict_causally if config is not None else False
|
||||
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
|
||||
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
|
||||
len_loss_factor = config.experimental.len_loss_factor if config is not None else 0.00001
|
||||
logit_normalization = config.experimental.logit_normalization if config is not None else 0
|
||||
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
|
||||
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
|
||||
|
@ -362,6 +363,7 @@ class Base_V2(nn.Module):
|
|||
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||
self.noncausal_masks = noncausal_masks
|
||||
self.audio_level_loss_factors = audio_level_loss_factors
|
||||
self.len_loss_factor = len_loss_factor
|
||||
self.logit_normalization = False # this actually kills the model's demasking capabilities
|
||||
self.use_segmented_attention_mask = use_segmented_attention_mask
|
||||
|
||||
|
@ -1047,7 +1049,7 @@ class Base_V2(nn.Module):
|
|||
|
||||
# check if len logits are provided
|
||||
if logits_aux is not None:
|
||||
len_factor = 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
|
||||
len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
|
||||
aux_loss_logit = torch.cat( logits_aux )
|
||||
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
||||
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
|
Loading…
Reference in New Issue
Block a user