this may bite me in the ass
This commit is contained in:
parent 2dfef693c4
commit b0dba9db07
@@ -69,6 +69,10 @@ To be evaluated, as additional training time is required, despite progression se

 At a glance, compared to the prior model setup, this implementation allows the model to better represent speech, as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.

-Additionally, this implementation paves the way for live decoding of the audio under the autoregressive mode (if trained for it).
+Additionally, this implementation paves the way for a ton of neat features, such as:
+* live playback through autoregressive inferencing, as all codebooks are predicted for each step
+  * could also be "mocked" by doing NAR-len demasking in chunks
+* inherent audio upscaling, as the model is trained on a 44kHz codec

 However, I'm not sure if the additional complexity justifies it.
 * the current hurdle is that speaker similarity is ***dismal***
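For the "mocked" live playback idea in the bullets above, a chunked NAR-len demasking loop might look roughly like the sketch below. This is a hypothetical illustration, not code from the repo: `model`, `mask_token`, and the cosine remasking schedule are stand-ins for whatever the actual forward pass and scheduler do.

```python
# Hypothetical sketch only: iteratively demask one fixed-size chunk of codec frames.
# `model`, `mask_token`, and the schedule are assumptions, not the repo's API.
import math
import torch

def demask_chunk(model, chunk_len, n_codebooks, mask_token, steps=8, device="cpu"):
    # start with every codebook slot of every frame masked
    resps = torch.full((chunk_len, n_codebooks), mask_token, dtype=torch.long, device=device)
    for step in range(steps):
        logits = model(resps)                               # (chunk_len, n_codebooks, vocab)
        scores, preds = logits.softmax(dim=-1).max(dim=-1)  # greedy choice + its confidence
        resps = torch.where(resps == mask_token, preds, resps)
        # cosine schedule: fraction of slots that should stay masked for the next step
        frac_masked = math.cos((step + 1) / steps * math.pi / 2)
        n_remask = int(frac_masked * chunk_len * n_codebooks)
        if n_remask > 0:
            # re-mask the lowest-confidence slots so later steps can revise them
            idx = scores.flatten().topk(n_remask, largest=False).indices
            resps.view(-1)[idx] = mask_token
    return resps
```

Running something like this per chunk, and feeding each finished chunk to the codec decoder, would approximate streaming playback without the model being autoregressive over time.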
@@ -255,10 +255,11 @@ class Dataset:
         return self.duration_range[1]

 # collection of experimental variables that should not be tampered with unless you know what you're doing
 # to-do: clean this up
 @dataclass()
 class ModelExperimentalSettings:
-    hf: bool = False # strictly utilizes a HF model and handles converting input IDs / outputs accordingly
-    interleave: bool = False # use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
+    hf: bool = False # unused, strictly utilizes a HF model and handles converting input IDs / outputs accordingly
+    interleave: bool = False # unused, use an interleaved AR rather than a split AR + NAR (worse performance and results due to everything being causal)
     split_classifiers: bool = False # each RVQ level gets its own classifier / output proj / LM head rather than sharing one for all RVQ levels (to-do: also split for text/prom)
     audio_embedding_sums: bool = False # whether each pass uses the previous RVQ codes or only the current level
     # a model trained not summing audio embeddings *can* have this enabled without any apparent issues
@@ -297,6 +298,7 @@ class ModelExperimentalSettings:
     # * the model wouldn't also need to learn when to predict the token in place
     len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
     len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
+    parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking

     #
     logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
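For context on the `logit_normalization` field above: the cited LogitNorm paper (https://arxiv.org/abs/2205.09310) constrains each logit vector's L2 norm before the softmax / cross-entropy. A minimal sketch of that idea follows; the function name and the example `tau` value are illustrative, not taken from the repo.

```python
# Rough sketch of logit normalization per LogitNorm: divide each logit vector
# by (tau * its L2 norm) before computing the softmax or loss.
import torch

def normalize_logits(logits: torch.Tensor, tau: float = 0.04, eps: float = 1e-7) -> torch.Tensor:
    norms = logits.norm(p=2, dim=-1, keepdim=True) + eps
    return logits / (norms * tau)
```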
@@ -239,13 +239,13 @@ class AR_NAR_V2(Base_V2):

         # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough
         temperature = sampling_kwargs.pop("temperature", 0.0)
-        minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5)
+        minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 0) # 2.5)
         # this really helps keep audio coherent so far
         cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength)
         cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
         start_noise = sampling_kwargs.get("denoise_start", 0.0)
         end_noise = sampling_kwargs.get("denoise_end", 1.0)
-        remasking = sampling_kwargs.get("remasking", True)
+        remasking = sampling_kwargs.get("remasking", False)
         max_steps = math.floor(max_steps * (end_noise - start_noise))

         # to specify the initial mask used
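As a side note on the `cfg_strength` / `cfg_rescale` knobs in the hunk above: classifier-free guidance with rescaling usually combines the conditional and unconditional outputs, then pulls the result back toward the conditional statistics. A hedged sketch of the common formulation (names and shapes are assumptions, not the repo's internals):

```python
import torch

def apply_cfg(cond: torch.Tensor, uncond: torch.Tensor,
              cfg_strength: float = 2.5, cfg_rescale: float = 0.75) -> torch.Tensor:
    # standard classifier-free guidance
    guided = uncond + (cond - uncond) * cfg_strength
    if cfg_rescale > 0:
        # rescale so the guided output keeps roughly the same std as the conditional one,
        # then blend rescaled and unrescaled per the rescale factor
        rescaled = guided * (cond.std(dim=-1, keepdim=True) / guided.std(dim=-1, keepdim=True))
        guided = cfg_rescale * rescaled + (1.0 - cfg_rescale) * guided
    return guided
```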
@@ -648,6 +648,12 @@ class AR_NAR_V2(Base_V2):

         # is NAR
         if (len_list is not None or resps_list is not None) and phns_list is not None:
+            # to-do: verify this actually does return the input resps if they're already filled
+            """
+            if resps_list is not None:
+                return resps_list
+            """
+
             return self.forward_nar_masked(
                 task_list=task_list,
@@ -275,6 +275,7 @@ class Base_V2(nn.Module):
         logit_normalization = config.experimental.logit_normalization if config is not None else 0
         per_level_normalization = config.experimental.per_level_normalization if config is not None else True
         use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+        parallel_attention_mask_dropout = config.experimental.parallel_attention_mask_dropout if config is not None else 0.0

         n_vocab = 256
         n_tasks = config.tasks if config is not None else 8
@@ -366,6 +367,7 @@ class Base_V2(nn.Module):
         self.len_loss_factor = len_loss_factor
         self.logit_normalization = False # this actually kills the model's demasking capabilities
         self.use_segmented_attention_mask = use_segmented_attention_mask
+        self.parallel_attention_mask_dropout = parallel_attention_mask_dropout

         self.sep = nn.Parameter(torch.randn(d_model))
@@ -1110,6 +1112,9 @@ class Base_V2(nn.Module):
         # right now limit to new versions because I need to retrain the model for noncausal masks...
         is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]

+        if self.parallel_attention_mask_dropout > 0:
+            is_causal = [ True if random.random() < self.parallel_attention_mask_dropout else m for m in is_causal ]
+
         # create special masks
         # to-do: create it if mixed (although I expect this model to be purely non-causal)
         aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
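To make the `is_causal` / mask-dropout change above concrete, here is a rough sketch of turning a per-sample `is_causal` flag into additive attention masks. The function and shapes are hypothetical and only meant to illustrate what a mixed causal / non-causal batch looks like:

```python
import torch

def build_attention_masks(is_causal: list[bool], seq_len: int, device="cpu") -> torch.Tensor:
    masks = []
    for causal in is_causal:
        if causal:
            # -inf above the diagonal blocks attention to future positions
            m = torch.full((seq_len, seq_len), float("-inf"), device=device).triu(diagonal=1)
        else:
            # non-causal: every position may attend to every other position
            m = torch.zeros((seq_len, seq_len), device=device)
        masks.append(m)
    return torch.stack(masks)  # (batch, seq_len, seq_len), additive mask
```

With an extra head dimension, a mask like this can be passed as the additive `attn_mask` of `torch.nn.functional.scaled_dot_product_attention`, so randomly flipping some samples to causal during training only changes the mask, not the model.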