this may bite me in the ass

This commit is contained in:
mrq 2025-03-17 21:46:50 -05:00
parent 2dfef693c4
commit b0dba9db07
4 changed files with 23 additions and 6 deletions

View File

@@ -69,6 +69,10 @@ To be evaluated, as additional training time is required, despite progression se
At a glance, compared to the prior model setup, this implementation allows for the model to better represent speech as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.
Additionally, this implementation paves the way for live decoding of the audio under the autoregressive mode (if trained for it).
Additionally, this implementation paves the way for a ton of neat features, such as:
* live playback through autoregressive inferencing, as all codebooks are predicted for each step
* could also be "mocked" by doing NAR-len demasking in chunks (see the sketch after this list)
* inherent audio upscaling, as the model is trained on a 44KHz codec
However, I'm not sure if the additional complexity justifies it.
* the current hurdle is that speaker similarity is ***dismal***
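As a rough illustration of the chunked demasking idea above, here is a hypothetical sketch (not this repo's code; `model` is assumed to be a callable returning per-position logits over the codebook vocabulary):

```python
# Hypothetical sketch: pseudo-live playback by demasking fixed-size chunks and yielding each
# chunk as soon as it is fully unmasked. All names here are illustrative, not the repo's API.
import math
import torch

MASK_TOKEN = -1  # stand-in id for a masked position

def demask_chunk(model, cond, length: int, steps: int = 8) -> torch.Tensor:
	tokens = torch.full((length,), MASK_TOKEN, dtype=torch.long)
	per_step = math.ceil(length / steps)
	for _ in range(steps):
		still_masked = tokens.eq(MASK_TOKEN)
		if not still_masked.any():
			break
		logits = model(cond, tokens)                      # assumed shape: [length, vocab]
		probs, preds = logits.softmax(dim=-1).max(dim=-1)
		probs = probs.masked_fill(~still_masked, -1.0)    # never reconsider decided positions
		chosen = probs.topk(min(per_step, int(still_masked.sum()))).indices
		tokens[chosen] = preds[chosen]                    # reveal the most confident positions
	return tokens

def stream(model, cond, total_len: int, chunk_len: int = 75):
	# yield each decoded chunk immediately instead of waiting for the full utterance
	for start in range(0, total_len, chunk_len):
		yield demask_chunk(model, cond, min(chunk_len, total_len - start))
```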

View File

@@ -255,10 +255,11 @@ class Dataset:
return self.duration_range[1]
# collection of experimental variables that should not be tampered with unless you know what you're doing
# to-do: clean this up
@dataclass()
class ModelExperimentalSettings:
hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
hf: bool = False # unused, strictly utilizes a HF model and handles converting input IDs / outputs accordingly
interleave: bool = False # unused, use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
# a model trained not summing audio embeddings *can* have this enabled without any apparent issues
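For illustration, a minimal sketch (hypothetical names, not this repo's modules) of what `split_classifiers` implies: one output projection per RVQ level instead of a single shared LM head.

```python
import torch
from torch import nn

class Classifiers(nn.Module):
	def __init__(self, d_model: int, n_vocab: int, n_levels: int, split: bool):
		super().__init__()
		self.split = split
		if split:
			# one head per RVQ level
			self.heads = nn.ModuleList([nn.Linear(d_model, n_vocab) for _ in range(n_levels)])
		else:
			# a single head shared by every level
			self.head = nn.Linear(d_model, n_vocab)

	def forward(self, hidden: torch.Tensor, level: int) -> torch.Tensor:
		return self.heads[level](hidden) if self.split else self.head(hidden)
```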
@@ -297,6 +298,7 @@ class ModelExperimentalSettings:
# * the model wouldn't also need to learn when to predict the token in place
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
#
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
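For reference, a minimal sketch of logit normalization as described in the first linked paper (https://arxiv.org/abs/2205.09310), where `tau` is that paper's temperature hyper-parameter; how it is wired into this repo's per-level losses is not shown here.

```python
import torch
import torch.nn.functional as F

def logit_norm_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, tau: float = 0.04) -> torch.Tensor:
	# divide each logit vector by its L2 norm (scaled by tau) before the cross-entropy
	norms = logits.norm(p=2, dim=-1, keepdim=True) + 1e-7
	return F.cross_entropy(logits / (norms * tau), targets)
```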

View File

@@ -239,13 +239,13 @@ class AR_NAR_V2(Base_V2):
# greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
temperature = sampling_kwargs.pop("temperature", 0.0)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 0) # 2.5)
# this really helps keep audio coherent so far
cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", True)
remasking = sampling_kwargs.get("remasking", False)
max_steps = math.floor(max_steps * (end_noise - start_noise))
# to specify the initial mask used
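As a point of reference for `cfg_strength` / `cfg_rescale`, a hypothetical sketch of how classifier-free guidance with rescaling is commonly applied to a conditional/unconditional pair of logits; this is not necessarily the repo's exact formulation.

```python
import torch

def apply_cfg(cond: torch.Tensor, uncond: torch.Tensor, strength: float, rescale: float) -> torch.Tensor:
	# push predictions away from the unconditional output
	guided = uncond + (cond - uncond) * strength
	if rescale > 0:
		# pull the guided logits' scale back toward the conditional statistics to limit drift
		ratio = cond.std(dim=-1, keepdim=True) / (guided.std(dim=-1, keepdim=True) + 1e-6)
		guided = rescale * (guided * ratio) + (1.0 - rescale) * guided
	return guided
```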
@@ -648,6 +648,12 @@ class AR_NAR_V2(Base_V2):
# is NAR
if (len_list is not None or resps_list is not None) and phns_list is not None:
# to-do: verify this actually does return the input resps if they're already filled
"""
if resps_list is not None:
	return resps_list
"""
return self.forward_nar_masked(
task_list=task_list,

View File

@@ -275,6 +275,7 @@ class Base_V2(nn.Module):
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
@@ -366,6 +367,7 @@ class Base_V2(nn.Module):
self.len_loss_factor = len_loss_factor
self.logit_normalization = False # this actually kills the model's demasking capabilities
self.use_segmented_attention_mask = use_segmented_attention_mask
self.parallel_attention_mask_dropout = parallel_attention_mask_dropout
self.sep = nn.Parameter(torch.randn(d_model))
@@ -1110,6 +1112,9 @@ class Base_V2(nn.Module):
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
if self.parallel_attention_mask_dropout > 0:
	# randomly swap a non-causal mask for a causal one while training NAR-len demasking
	is_causal = [ True if random.random() < self.parallel_attention_mask_dropout else m for m in is_causal ]
# create special masks
# to-do, create it if mixed (although I expect this model to be purely non-causal)
aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
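For context, a hypothetical sketch (illustrative names only) of expanding a per-item `is_causal` flag into an explicit boolean attention mask, which is roughly what the mask handling above has to produce per batch item.

```python
import torch

def build_attention_masks(is_causal: list, seq_len: int, device=None) -> torch.Tensor:
	masks = []
	for causal in is_causal:
		if causal:
			# lower-triangular: each position attends to itself and the past only
			masks.append(torch.ones(seq_len, seq_len, device=device).tril().bool())
		else:
			# full attention: parallel (NAR-style) decoding sees the whole sequence
			masks.append(torch.ones(seq_len, seq_len, device=device).bool())
	# [batch, 1, seq_len, seq_len], broadcastable across attention heads
	return torch.stack(masks).unsqueeze(1)
```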