reintroduced sampler_type = speaker because I think this might salvage the nemo model to have better speaker similarities

This commit is contained in:
mrq 2025-04-03 19:01:10 -05:00
parent caad99ab78
commit 2e93438867
2 changed files with 29 additions and 26 deletions

View File

@ -56,29 +56,19 @@ While the model *can* still be trained as a hybrid AR/NAR, the reference `nemo-*
Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is the same as the previous implementation, where each demasking step can mask off an entire timestep on the sum of the logit scores, or independently (where each level has its own mask). Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is the same as the previous implementation, where each demasking step can mask off an entire timestep on the sum of the logit scores, or independently (where each level has its own mask).
Unlike the previous implementation, ~~duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.~~ duration prediction is trained non-autoregressively by taking a page from my image classification endeavors. By having the logit be flattened then reshaped, the need to autoregressively decode for the duration or decode the last N tokens for the duration is not necessary anymore. Quasi-similarly to the previous implementation, duration prediction is trained through an explicit task, but unlike the previous implementation, this does not need to be autoregressively inferenced. By making use of my findings with a classifier-based OCR model, duration prediction can be done with one "token" and reshaping it into several digits for the final logits.
* Instead of a discrete, logit based output, it's possible to instead output a raw float to correspond to the seconds and train using `mse_loss` (or maybe `kl_div`), but experimentation shows that it's quite a pickle to train, even with weighing its loss down considerably.
* This *could* be trained in parallel using clever tricks with the attention mask, but a regression in the model/code suggests it's not worth wrangling for this feature.
#### Attention #### Attention
Unlike the previous implementation, attention needs to be paid towards the attention mask used. As suggested from the prior implementation, attention needs to be paid to the attention mask used, as it's quite easy to have the model degrade from a silent problem.
Previously, a full non-causal attention mask was employed, allowing for every token to attend to every other token. This is *sort of* fine, but is unneccessary, as some segments do not need to attend to other segments. A naive, fully non-causal attention just works, and while it seems a little unintuitive to have the input prompt attend to the output, I imagine the model does some bookkeeping in the input portion of the prompt. By constraining the input to not attend to the output, the model also grows constrained in its performance.
This new implementation aims to restrict each segment from attending to future segments. In other words, the input text does not need to attend to the audio tokens, while the reference audio does not need to attend to the output. Prior (failed) experimentation with exotic masks include:
* *Technically*, the reference audio doesn't need to attend to the input text, but it could allow for the model to explicitly map phonemes to the reference prompt. * a segmented mask, where each segment can only attend to itself and the prior segment.
* sliding attention, where each segment can only attend to its own window instead of its entire segment.
Additionally, sliding window attention is supported in this implementation. but further experimentation is required.
* This should also allow the model to decouple from requiring a strict duration window for output, in theory.
* The fundamental principle behind this is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window is beneficial. However, I imagine the theory on why this doesn't work so well is that the model has established a non-trivial dependency on the entire utterance.
* I imagine in a broader sense, it can ensure coherency by ensuring an utterance is delivered in a similar way, or the model derives a "speaker" from utilizing the existing utterance tokens.
This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation shows that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.
Lastly, partial attention allows for some clever tricks to train additional things in parallel, such as duration prediction.
* previously, this was naively assigned to computing loss against the duration at the last separator of a sequence (the one before the input prompt and output audio)
* however, a regression seemed to have caused this quirk to disappear, requiring falling back to explicit duration training
* a solution is to properly train this by injecting the "len" predictor token anywhere in the prompt, but take extra care in the 4D attention mask to only allow that token to attend to the input, and not have any other token attend to it.
### Pure AR ### Pure AR
@ -120,7 +110,7 @@ audio_level_loss_factors: "normal" # distribution of loss weights per codebook (
masking_train_p: 1.0 # pure AR masking_train_p: 1.0 # pure AR
masking_ratio: 0.8 # fixed mask ratio proves to be better masking_ratio: 0.8 # fixed mask ratio proves to be better
ignore_inputs_for_loss: True # False is not implemented ignore_inputs_for_loss: True # False is not implemented
use_segmented_attention_mask: True # restricts each section within its own section + prior section (in other words, does not let the text/prom see further into the future outside of its segment), also enables parallel duration training use_segmented_attention_mask: False #
use_streamlined_calc_loss: True # False has no effect now use_streamlined_calc_loss: True # False has no effect now
len_loss_factor: 0.0001 # start with the default for a while to not let duration training overpower the model, then gradually increase this (but this may only be required when introducing duration training on existing weights) len_loss_factor: 0.0001 # start with the default for a while to not let duration training overpower the model, then gradually increase this (but this may only be required when introducing duration training on existing weights)
@ -141,6 +131,7 @@ These settings should be avoided:
* `parallel_attention_mask_dropout`: this governs the rate of flipping to a causal (triangle) mask for training * `parallel_attention_mask_dropout`: this governs the rate of flipping to a causal (triangle) mask for training
* there's *some* reason to do this ablation, but it ruins the model (but the model can easily recover if erroneously trained with this) * there's *some* reason to do this ablation, but it ruins the model (but the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with. * the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_segmented_attention_mask`: training metrics suggests this is fine, but real world usage shows it's not.
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said in the beginning of the utterance shouldn't affect what's aid at the end * `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said in the beginning of the utterance shouldn't affect what's aid at the end
* however, it's possible this is a detriment itself, but further experimentation is needed * however, it's possible this is a detriment itself, but further experimentation is needed
* `len_parallel_training`: this uses a clever quirk with how attention works to train duration prediction alongside normal TTS tasks * `len_parallel_training`: this uses a clever quirk with how attention works to train duration prediction alongside normal TTS tasks

View File

@ -1207,9 +1207,15 @@ class Dataset(_Dataset):
bos_id, space_id, eos_id = self.empty_text bos_id, space_id, eos_id = self.empty_text
speaker_id, utterance_id = self.paths[index] if self.sampler_type == "speaker":
speaker_name = self.speakers[speaker_id] speaker_id = index
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id] speaker_name = self.speakers[speaker_id]
utterance_name = random.choice( list(self.metadata[speaker_name].keys()) ) # random.choice(self.metadata[speaker_name])
else:
speaker_id, utterance_id = self.paths[index]
speaker_name = self.speakers[speaker_id]
utterance_name = list(self.metadata[speaker_name].keys())[utterance_id]
path = cfg.data_dir / speaker_name / utterance_name path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5:
@ -1916,10 +1922,16 @@ if __name__ == "__main__":
symmap = get_phone_symmap() symmap = get_phone_symmap()
for index in tqdm(range(len( dataset )), desc="Processing dataset..."): for index in tqdm(range(len( dataset )), desc="Processing dataset..."):
speaker_id, utterance_id = dataset.paths[index] if dataset.sampler_type == "speaker":
speaker_name = dataset.speakers[speaker_id] speaker_id = index
speaker_keys = list(dataset.metadata[speaker_name].keys()) speaker_name = dataset.speakers[speaker_id]
utterance_name = speaker_keys[utterance_id] utterance_name = random.choice( list(dataset.metadata[speaker_name].keys()) ) # random.choice(dataset.metadata[speaker_name])
else:
speaker_id, utterance_id = dataset.paths[index]
speaker_name = dataset.speakers[speaker_id]
speaker_keys = list(dataset.metadata[speaker_name].keys())
utterance_name = speaker_keys[utterance_id]
path = cfg.data_dir / speaker_name / utterance_name path = cfg.data_dir / speaker_name / utterance_name
if cfg.dataset.use_hdf5: if cfg.dataset.use_hdf5: