is this my last cope (falling back to explicit duration prediction, as this regression just won't go away) (also the smaller model was lobotomized because of my ROCm setup having a botched SDPA for who knows why)

This commit is contained in:
mrq 2025-04-02 17:01:24 -05:00
parent 7a0956863d
commit 0e995dbf2c
6 changed files with 56 additions and 28 deletions

View File

@ -56,7 +56,7 @@ While the model *can* still be trained as a hybrid AR/NAR, the reference `nemo-*
Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is the same as the previous implementation, where each demasking step can mask off an entire timestep on the sum of the logit scores, or independently (where each level has its own mask).
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.
Unlike the previous implementation, ~~duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.~~ duration prediction is trained non-autoregressively by taking a page from my image classification endeavors. By having the logit be flattened then reshaped, the need to autoregressively decode for the duration or decode the last N tokens for the duration is not necessary anymore.
#### Attention
@ -66,18 +66,20 @@ Previously, a full non-causal attention mask was employed, allowing for every to
This new implementation aims to restrict each segment from attending to future segments. In other words, the input text does not need to attend to the audio tokens, while the reference audio does not need to attend to the output.
* *Technically*, the reference audio doesn't need to attend to the input text, but it could allow for the model to explicitly map phonemes to the reference prompt.
* Unfortunately, this does not seem to work for the `nemo-smaller` model
* I'm not too sure why this is the case, but I suppose it's just how that model's weights progressed, as the `nemo-larger` model seems fine.
Additionally, sliding window attention is supported in this implementation, but has shown big regressions when performing additional training on existing weights.
Additionally, sliding window attention is supported in this implementation. but further experimentation is required.
* This should also allow the model to decouple from requiring a strict duration window for output, in theory.
* The fundamental principle behind this is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window is beneficial. However, I imagine the theory on why this doesn't work so well is that the model has established a non-trivial dependency on the entire utterance.
* I imagine in a broader sense, it can ensure coherency by ensuring an utterance is delivered in a similar way, or the model derives a "speaker" from utilizing the existing utterance tokens.
* A fresh model *could* have no issues, as it wouldn't be enough of a detriment.
* An existing model *could* be coerced with enough time, but I am not patient enough of a man to wait.
This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation shows that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.
Lastly, partial attention allows for some clever tricks to train additional things in parallel, such as duration prediction.
* previously, this was naively assigned to computing loss against the duration at the last separator of a sequence (the one before the input prompt and output audio)
* however, a regression seemed to have caused this quirk to disappear, requiring falling back to explicit duration training
* a solution is to properly train this by injecting the "len" predictor token anywhere in the prompt, but take extra care in the 4D attention mask to only allow that token to attend to the input, and not have any other token attend to it.
### Pure AR
Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
@ -124,7 +126,7 @@ len_loss_factor: 0.0001 # start with the default for a while to not let duration
noncausal_masks: True # creates non-causal masks
resp_parallel_training: True # trains all codebook levels in parallel
len_parallel_training: True # trains length duration alongside normal training
len_parallel_training: False # trains length duration alongside normal training
cfg_cond_dropout_p: 0.02 # was originally 0.3, but I think it's too much after a while
cfg_prom_dropout_p: 0.01 # was originally 0.2
@ -140,10 +142,10 @@ These settings should be avoided:
* there's *some* reason to do this ablation, but it ruins the model (but the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said in the beginning of the utterance shouldn't affect what's aid at the end
* however, it seems this is a detriment to the model, I imagine because the model could rely on how something sounds earlier on, even if there shouldn't be a direct causal relationship
* this could be something that might need to be trained from the very beginning rather than early on, but training existing models does not seem to fare well
* `nemo-smaller-llama-8` seemed to have degraded far more than `nemo-larger-llama-8` did. I suppose the head count / size might matter.
* this could also have been caused by a regression in the code due to dumb Python aliasing behaviors
* however, it's possible this is a detriment itself, but further experimentation is needed
* `len_parallel_training`: this uses a clever quirk with how attention works to train duration prediction alongside normal TTS tasks
* however, it seems there's a regression that caused this to stop working consistently
* disabling this falls back to explicitly training a `len` task (like the old implementation)
## Benefits and Caveats

View File

@ -368,7 +368,7 @@ class Model:
return [ self ] if not name or self.name == name else []
def loss_factor(self, k):
return self.loss_factors.get(k, 0.0)
return self.loss_factors.get(k, 1.0)
@property
def max_levels(self):

View File

@ -24,24 +24,28 @@ try:
from .codecs.encodec import *
except Exception as e:
cfg.inference.use_encodec = False
raise e
_logger.warning(str(e))
try:
from .codecs.vocos import *
except Exception as e:
cfg.inference.use_vocos = False
raise e
_logger.warning(str(e))
try:
from .codecs.dac import *
except Exception as e:
cfg.inference.use_dac = False
#raise e
_logger.warning(str(e))
try:
from .codecs.nemo import *
except Exception as e:
cfg.inference.use_nemo = False
raise e
_logger.warning(str(e))
@cache

View File

@ -61,6 +61,11 @@ class TTS():
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
# fallback to encodec if no vocos
if cfg.audio_backend == "vocos" and not cfg.inference.use_vocos:
_logger.warning("Vocos requested but not available, falling back to Encodec...")
cfg.set_audio_backend(cfg.audio_backend)
if amp is None:
amp = cfg.inference.amp
if dtype is None or dtype == "auto":

View File

@ -943,7 +943,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
task = "tts" # if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
texts.append( text )
proms.append( prom )

View File

@ -100,7 +100,7 @@ class FiniteAudioEncoder(nn.Module):
if not d_model:
d_model = token_dim
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
if use_ffn:
@ -405,7 +405,8 @@ class Base_V2(nn.Module):
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = None # ml.Embedding(11, d_model)
self.len_emb = nn.Parameter(torch.randn(d_model)) # ugh
self.audio_emb = None
self.proms_emb = None
@ -640,6 +641,11 @@ class Base_V2(nn.Module):
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert len marker
if resps_list is not None:
inputs[i].append( ( "len", torch.tensor([resps_list[i].shape[0]]) ) )
else:
inputs[i].append( ( "len", torch.tensor([0]) ) )
inputs[i].append( ("classifier_level", "len") )
# Speech-to-Text prediction task
@ -758,18 +764,14 @@ class Base_V2(nn.Module):
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
# singleton marker
embedding = self.len_emb[None]
else:
# should probably raise an exception so things aren't processed silently
continue
batch.append(embedding)
# needed, cringe
if task_type == "len":
#batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
x_list.append( _join( batch, self.sep ) )
return x_list
@ -889,7 +891,7 @@ class Base_V2(nn.Module):
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
causal = True
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
@ -917,6 +919,7 @@ class Base_V2(nn.Module):
# non-tokened tasks
if name in non_tokened_names:
continue
# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
@ -932,6 +935,9 @@ class Base_V2(nn.Module):
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
elif name == "len":
size = input[0].item()
token = torch.tensor([ int(i) for i in str( size ).zfill(5) ], device=device, dtype=torch.int64)
# not a special input, inject as-is
else:
token = input
@ -956,9 +962,9 @@ class Base_V2(nn.Module):
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue
logit = logits[batch_index][start:end]
logit = logits[batch_index][start:end]
"""
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
@ -968,9 +974,13 @@ class Base_V2(nn.Module):
l = self.causal_size
loss_targets.append( token[l:].long() ) # shift the target so that token n...
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
elif name == "len":
loss_targets.append( token.long() )
loss_logits.append( logit.squeeze(0) )
else:
loss_targets.append( token.long() )
loss_logits.append( logit )
loss_factors.append( loss_factor )
loss_names.append( name )
else:
@ -1159,6 +1169,8 @@ class Base_V2(nn.Module):
lens[1] = input.shape[0] + 1
elif name == "tone":
lens[1] += 2
elif name == "len":
lens[2] = 2
elif name == "resp":
lens[2] = input.shape[0]
@ -1179,8 +1191,8 @@ class Base_V2(nn.Module):
hidden_states = output.hidden_states
logits = self.audio_decoder( output.logits )
"""
# logits = self.audio_decoder( output.logits )
logits = [ logit for logit in output.logits ]
logits_aux = None
@ -1213,11 +1225,16 @@ class Base_V2(nn.Module):
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
"""
# Remove padding
logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level != "len":
continue
logits[batch_index] = logits[batch_index].view(-1, 5, 10)
if not training:
loss = None
stats = None