From b0dba9db075a8c5a5f1ab7e1344248e934567531 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 17 Mar 2025 21:46:50 -0500
Subject: [PATCH] this may bite me in the ass

---
 docs/models_v2.md          |  8 ++++++--
 vall_e/config.py           |  6 ++++--
 vall_e/models/ar_nar_v2.py | 10 ++++++++--
 vall_e/models/base_v2.py   |  5 +++++
 4 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/docs/models_v2.md b/docs/models_v2.md
index 536ed03..618ed5c 100644
--- a/docs/models_v2.md
+++ b/docs/models_v2.md
@@ -69,6 +69,10 @@ To be evaluated, as additional training time is required, despite progression se
 
 At a glance, compared to the prior model setup, this implementation allows for the model to better represent speech as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.
 
-Additionally, this implementation paves the way for live decoding of the audio under the autoregressive mode (if trained for it).
+Additionally, this implementation paves the way for a ton of neat features, such as:
+* live playback through autoregressive inferencing, as all codebooks are predicted for each step
+  * could also be "mocked" by doing NAR-len demasking in chunks
+* inherent audio upscaling, as the model is trained on a 44KHz codec
 
-However, I'm not sure if the additional complexity justifies it.
\ No newline at end of file
+However, I'm not sure if the additional complexity justifies it.
+The current hurdle is that speaker similarity is ***dismal***.
\ No newline at end of file
diff --git a/vall_e/config.py b/vall_e/config.py
index d6e6252..951214a 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -255,10 +255,11 @@ class Dataset:
 		return self.duration_range[1]
 
 # collection of experimental variables that should not be tampered with unless you know what you're doing
+# to-do: clean this up
 @dataclass()
 class ModelExperimentalSettings:
-	hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
-	interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
+	hf: bool = False # unused, strictly utilizes a HF model and handles converting input IDs / outputs accordingly
+	interleave: bool = False # unused, use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
 	split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
 	audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
 	# a model trained not summing audio embeddings *can* have this enabled without any apparent issues
@@ -297,6 +298,7 @@ class ModelExperimentalSettings:
 	# * the model wouldn't also need to learn when to predict the token in place
 	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
 	len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
+	parallel_attention_mask_dropout: float = 0.0 # probability of swapping in a causal attention mask when training NAR-len demasking
 
 	# logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
 
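A note on the new `parallel_attention_mask_dropout` knob (wired up in base_v2.py below): during NAR-len demasking training every sequence normally receives a non-causal (parallel) attention mask, and the knob flips each one to a causal mask with the given probability, so the model still sees causal attention some of the time. A minimal, self-contained sketch of the mechanic (illustrative only, not the repo's code; the function name is made up):

	import random

	def dropout_parallel_masks(is_causal: list[bool], p: float) -> list[bool]:
		# with probability p, force a causal mask on entries that would
		# otherwise attend bidirectionally (the NAR-len "parallel" case)
		return [ True if random.random() < p else m for m in is_causal ]

	# e.g. a fully non-causal NAR-len batch of 8 sequences, p = 0.25
	print(dropout_parallel_masks([False] * 8, 0.25))
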
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index df7ddbc..2a9df49 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -239,13 +239,13 @@ class AR_NAR_V2(Base_V2):
 		# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
 		temperature = sampling_kwargs.pop("temperature", 0.0)
-		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
+		minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 0) # 2.5)
 		# this really helps keep audio coherent so far
 		cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
 		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
-		remasking = sampling_kwargs.get("remasking", True)
+		remasking = sampling_kwargs.get("remasking", False)
 
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
 
 		# to specify the initial mask used
@@ -648,6 +648,12 @@ class AR_NAR_V2(Base_V2):
 
 
 		# is NAR
 		if (len_list is not None or resps_list is not None) and phns_list is not None:
+			# to-do: verify this actually does return the input resps if they're already filled
+			"""
+			if resps_list is not None:
+				return resps_list
+			"""
+
 			return self.forward_nar_masked(
 				task_list=task_list,
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index fc78d0f..bb12eb7 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -275,6 +275,7 @@ class Base_V2(nn.Module):
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
 		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+		parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
 
 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -366,6 +367,7 @@ class Base_V2(nn.Module):
 		self.len_loss_factor = len_loss_factor
 		self.logit_normalization = False # this actually kills the model's demasking capabilities
 		self.use_segmented_attention_mask = use_segmented_attention_mask
+		self.parallel_attention_mask_dropout = parallel_attention_mask_dropout
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -1110,6 +1112,9 @@ class Base_V2(nn.Module):
 		# right now limit to new versions because I need to retrain the model for noncausal masks...
 		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
 
+		if self.parallel_attention_mask_dropout > 0:
+			is_causal = [ True if random.random() < self.parallel_attention_mask_dropout else m for m in is_causal ]
+
 		# create special masks
 		# to-do, create it if mixed (although I expect this model to be purely non-causal)
 		aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
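For reference on the sampler knobs touched above: `cfg_strength` scales classifier-free guidance, and `cfg_rescale` blends in the rescaling trick from https://arxiv.org/abs/2305.08891 so the guided logits don't drift away from the conditional output's scale. A rough sketch of that formulation (an assumption about the intent, not lifted from the repo; `cond`/`uncond` stand for conditional and unconditional logits):

	import torch

	def cfg_with_rescale(cond: torch.Tensor, uncond: torch.Tensor, strength: float, rescale: float) -> torch.Tensor:
		# standard classifier-free guidance: push logits away from the unconditional pass
		guided = uncond + strength * (cond - uncond)
		if rescale > 0:
			# renormalize the guided logits to the conditional output's scale,
			# then blend between the rescaled and raw guided logits
			renormed = guided * (cond.std() / guided.std())
			guided = rescale * renormed + (1.0 - rescale) * guided
		return guided

Since `minimum_cfg_strength` and `remasking` are sampling-time defaults rather than trained parameters, changing them here only affects inference behavior, not already-trained weights.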