is this my last cope (falling back to explicit duration prediction, as this regression just won't go away) (also the smaller model was lobotomized because of my ROCm setup having a botched SDPA for who knows why)
parent 7a0956863d
commit 0e995dbf2c
@ -56,7 +56,7 @@ While the model *can* still be trained as a hybrid AR/NAR, the reference `nemo-*
Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is also unchanged: each demasking step can either mask off an entire timestep based on the sum of its logit scores across codebook levels, or handle each level independently (where each level has its own mask).
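As a rough illustration, one pass of that demasking loop might look like the sketch below; the `model(cond, resp)` call signature, the linear reveal schedule, and `mask_id` are assumptions rather than the actual implementation.

```python
import torch

@torch.no_grad()
def demask_decode(model, cond, T, n_levels=8, n_steps=8, mask_id=1024):
    resp = torch.full((T, n_levels), mask_id)     # start with every timestep masked
    masked = torch.ones(T, dtype=torch.bool)

    for step in range(n_steps):
        logits = model(cond, resp)                # assumed shape: (T, n_levels, vocab)
        conf, pred = logits.softmax(dim=-1).max(dim=-1)
        scores = conf.sum(dim=-1)                 # sum of per-level confidences per timestep

        # reveal the most confident still-masked timesteps on a simple linear schedule;
        # a per-level variant would keep a separate mask for each codebook level instead
        n_reveal = int(T * (step + 1) / n_steps) - int((~masked).sum())
        if n_reveal <= 0:
            continue
        candidates = scores.masked_fill(~masked, float("-inf"))
        idx = candidates.topk(n_reveal).indices
        resp[idx] = pred[idx]
        masked[idx] = False

    return resp
```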
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.
Unlike the previous implementation, ~~duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.~~ duration prediction is trained non-autoregressively by taking a page from my image classification endeavors. By flattening then reshaping the logit, there is no longer any need to autoregressively decode the duration, or to decode the last N tokens for it.
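Concretely (per the `zfill(5)` / `view(-1, 5, 10)` handling further down in this diff), the duration ends up treated as five base-10 digits predicted from a single position. A minimal sketch of that idea, with hypothetical helper names:

```python
import torch
import torch.nn.functional as F

def len_targets(duration: int) -> torch.Tensor:
    # e.g. 1234 frames -> digits [0, 1, 2, 3, 4]
    return torch.tensor([int(d) for d in str(duration).zfill(5)], dtype=torch.long)

def len_loss(len_logit: torch.Tensor, duration: int) -> torch.Tensor:
    # len_logit: the flattened logit at the len position (50 = 5 digits * 10 classes)
    digits = len_logit.view(5, 10)                 # one 10-way classification per digit
    return F.cross_entropy(digits, len_targets(duration))

def len_predict(len_logit: torch.Tensor) -> int:
    digits = len_logit.view(5, 10).argmax(dim=-1)  # most likely digit per position
    return int("".join(str(d.item()) for d in digits))
```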
#### Attention
@ -66,18 +66,20 @@ Previously, a full non-causal attention mask was employed, allowing for every to
This new implementation aims to restrict each segment from attending to future segments. In other words, the input text does not need to attend to the audio tokens, while the reference audio does not need to attend to the output (a sketch of such a segment-level mask follows the notes below).
* *Technically*, the reference audio doesn't need to attend to the input text either, but allowing it to could let the model explicitly map phonemes to the reference prompt.
* Unfortunately, this does not seem to work for the `nemo-smaller` model
* I'm not too sure why this is the case, but I suppose it's just how that model's weights progressed, as the `nemo-larger` model seems fine.
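A minimal sketch of the segment-level masking idea described above; the segment ids and the `attend_up_to` mapping are assumptions, and the real code expands such a boolean mask into the 4D mask the attention backend expects.

```python
import torch

def segment_mask(seg_ids: torch.Tensor, attend_up_to: dict) -> torch.Tensor:
    # seg_ids: (seq_len,) segment index per token, e.g. 0 = text, 1 = prom, 2 = resp
    # attend_up_to: highest segment index each segment is allowed to attend to
    q, k = seg_ids[:, None], seg_ids[None, :]
    allowed = torch.zeros(seg_ids.numel(), seg_ids.numel(), dtype=torch.bool)
    for seg, max_seg in attend_up_to.items():
        allowed |= (q == seg) & (k <= max_seg)
    return allowed  # True where attention is permitted

# text only sees text, prom sees text + prom, resp sees everything
seg_ids = torch.tensor([0] * 4 + [1] * 6 + [2] * 8)
mask = segment_mask(seg_ids, {0: 0, 1: 1, 2: 2})
```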
Additionally, sliding window attention is supported in this implementation, but has shown big regressions when performing additional training on existing weights.
Additionally, sliding window attention is supported in this implementation, but further experimentation is required (a sketch of the window mask follows the notes below).
* This should also allow the model to decouple from requiring a strict duration window for output, in theory.
* The fundamental principle behind this is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window should be beneficial. However, I imagine this doesn't work so well because the model has already established a non-trivial dependency on the entire utterance.
* In a broader sense, I imagine attending to the whole utterance maintains coherency by keeping the delivery consistent, or lets the model derive a "speaker" from the existing utterance tokens.
* A fresh model *could* have no issues, as it wouldn't be enough of a detriment.
* An existing model *could* be coerced with enough time, but I am not patient enough of a man to wait.
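For reference, the window itself is just a band around the diagonal; a sketch (window size and symmetry are assumptions), which would be AND-ed with the segment mask above so the window only applies within a segment:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where |query - key| <= window, so each position only attends to a
    # local neighborhood instead of the entire utterance
    idx = torch.arange(seq_len)
    return (idx[:, None] - idx[None, :]).abs() <= window
```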
This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation show that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.
Lastly, partial attention allows for some clever tricks to train additional things in parallel, such as duration prediction.
* previously, this was naively done by computing the loss against the duration at the last separator of a sequence (the one before the input prompt and output audio)
* however, a regression seems to have caused this quirk to disappear, requiring a fallback to explicit duration training
* a solution is to properly train this by injecting the "len" predictor token anywhere in the prompt, taking extra care in the 4D attention mask to only allow that token to attend to the input, and to not let any other token attend to it (sketched below)
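A sketch of that mask carve-out (boolean-mask convention and the helper name are assumptions): the `len` position gets a query row limited to the input, and its key column is blocked for every other query so nothing downstream can attend to it.

```python
import torch

def carve_len_token(allowed: torch.Tensor, len_pos: int, input_positions: torch.Tensor) -> torch.Tensor:
    # allowed: (seq, seq) boolean mask, True where attention is permitted
    allowed = allowed.clone()
    allowed[len_pos, :] = False                 # the len token attends to...
    allowed[len_pos, input_positions] = True    # ...only the input segment
    allowed[len_pos, len_pos] = True            # (and itself)
    others = torch.arange(allowed.shape[0]) != len_pos
    allowed[others, len_pos] = False            # no other token may attend to the len token
    return allowed
```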
### Pure AR
Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
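A rough sketch of what pure-AR decoding amounts to; the call signature, stop handling, and sampling details are assumptions, the point being that each step emits one token per codebook for the next code-frame.

```python
import torch

@torch.no_grad()
def ar_decode(model, cond, max_frames: int, n_levels: int = 8, stop_id: int = 1024):
    frames = []                                  # each entry: (n_levels,) codes for one frame
    for _ in range(max_frames):
        logits = model(cond, frames)             # assumed shape: (n_levels, vocab) for the next frame
        frame = torch.distributions.Categorical(logits=logits).sample()
        if (frame == stop_id).any():             # a stop token on any level ends decoding
            break
        frames.append(frame)
    return torch.stack(frames) if frames else torch.empty(0, n_levels, dtype=torch.long)
```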
@ -124,7 +126,7 @@ len_loss_factor: 0.0001 # start with the default for a while to not let duration
noncausal_masks: True # creates non-causal masks
resp_parallel_training: True # trains all codebook levels in parallel
len_parallel_training: True # trains length duration alongside normal training
len_parallel_training: False # trains length duration alongside normal training
cfg_cond_dropout_p: 0.02 # was originally 0.3, but I think it's too much after a while
cfg_prom_dropout_p: 0.01 # was originally 0.2
@ -140,10 +142,10 @@ These settings should be avoided:
* there's *some* reason to do this ablation, but it ruins the model (though the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said at the beginning of the utterance shouldn't affect what's said at the end
* however, it seems this is a detriment to the model; I imagine it's because the model could rely on how something sounds earlier on, even if there shouldn't be a direct causal relationship
* this could be something that needs to be trained in from the very beginning rather than added early on, as training it into existing models does not seem to fare well
* `nemo-smaller-llama-8` seemed to have degraded far more than `nemo-larger-llama-8` did. I suppose the head count / size might matter.
* this could also have been caused by a regression in the code due to dumb Python aliasing behaviors
* however, it's possible the sliding mask itself is the detriment rather than the regression, but further experimentation is needed
* `len_parallel_training`: this uses a clever quirk with how attention works to train duration prediction alongside normal TTS tasks
* however, it seems there's a regression that caused this to stop working consistently
* disabling this falls back to explicitly training a `len` task (like the old implementation); a rough sketch of what that sample looks like follows
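Based on the input handling further down in this diff, that fallback roughly amounts to occasionally building a `len`-task sample whose target is the response length (the tuple names mirror the diff, but the function itself is only a sketch):

```python
import torch

def build_len_sample(text, prom, resp):
    # the duration target is later expanded into five zero-padded digits for the loss
    return [
        ("text", text),
        ("prom", prom),
        ("len", torch.tensor([resp.shape[0]])),
        ("classifier_level", "len"),
    ]
```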
## Benefits and Caveats
@ -368,7 +368,7 @@ class Model:
return [ self ] if not name or self.name == name else []

def loss_factor(self, k):
return self.loss_factors.get(k, 0.0)
return self.loss_factors.get(k, 1.0)

@property
def max_levels(self):
@ -24,24 +24,28 @@ try:
from .codecs.encodec import *
except Exception as e:
cfg.inference.use_encodec = False
raise e
_logger.warning(str(e))

try:
from .codecs.vocos import *
except Exception as e:
cfg.inference.use_vocos = False
raise e
_logger.warning(str(e))

try:
from .codecs.dac import *
except Exception as e:
cfg.inference.use_dac = False
#raise e
_logger.warning(str(e))

try:
from .codecs.nemo import *
except Exception as e:
cfg.inference.use_nemo = False
raise e
_logger.warning(str(e))

@cache
@ -61,6 +61,11 @@ class TTS():
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

# fallback to encodec if no vocos
if cfg.audio_backend == "vocos" and not cfg.inference.use_vocos:
_logger.warning("Vocos requested but not available, falling back to Encodec...")
cfg.set_audio_backend(cfg.audio_backend)

if amp is None:
amp = cfg.inference.amp
if dtype is None or dtype == "auto":
@ -943,7 +943,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
task = "tts" # if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"

texts.append( text )
proms.append( prom )
@ -100,7 +100,7 @@ class FiniteAudioEncoder(nn.Module):
if not d_model:
d_model = token_dim

self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim) * 0.02)
self.norm = nn.LayerNorm(token_dim) if use_ln else nn.Identity()
if use_ffn:
@ -405,7 +405,8 @@ class Base_V2(nn.Module):
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = None # ml.Embedding(11, d_model)

self.len_emb = nn.Parameter(torch.randn(d_model)) # ugh

self.audio_emb = None
self.proms_emb = None
@ -640,6 +641,11 @@ class Base_V2(nn.Module):
# insert tone token if we're trained for it
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
inputs[i].append( ( "tone", tone_list[i] ) )
# insert len marker
if resps_list is not None:
inputs[i].append( ( "len", torch.tensor([resps_list[i].shape[0]]) ) )
else:
inputs[i].append( ( "len", torch.tensor([0]) ) )

inputs[i].append( ("classifier_level", "len") )
# Speech-to-Text prediction task
@ -758,18 +764,14 @@ class Base_V2(nn.Module):
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
# singleton marker
embedding = self.len_emb[None]
else:
# should probably raise an exception so things aren't processed silently
continue

batch.append(embedding)

# needed, cringe
if task_type == "len":
#batch[-1] = torch.cat( [ batch[-1], self.sep[None], self.sep[None] ] )
batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )

x_list.append( _join( batch, self.sep ) )

return x_list
@ -889,7 +891,7 @@ class Base_V2(nn.Module):
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
causal = True
causal = False
task_type = "tts"
dropout_mask = None
classifier_level = None
@ -917,6 +919,7 @@ class Base_V2(nn.Module):
# non-tokened tasks
if name in non_tokened_names:
continue

# prom can either be a tensor itself or a list of tensors and strings
if name == "prom":
# expand to list if not a list
@ -932,6 +935,9 @@ class Base_V2(nn.Module):
# mask found, apply it
if dropout_mask is not None:
token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
elif name == "len":
size = input[0].item()
token = torch.tensor([ int(i) for i in str( size ).zfill(5) ], device=device, dtype=torch.int64)
# not a special input, inject as-is
else:
token = input
@ -956,9 +962,9 @@ class Base_V2(nn.Module):
loss_factor = self.loss_factor(name)
if loss_factor == 0.0:
continue

logit = logits[batch_index][start:end]

logit = logits[batch_index][start:end]

"""
if self.logit_normalization:
logit = logit_normalization( logit, self.logit_normalization )
@ -968,9 +974,13 @@ class Base_V2(nn.Module):
l = self.causal_size
loss_targets.append( token[l:].long() ) # shift the target so that token n...
loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
elif name == "len":
loss_targets.append( token.long() )
loss_logits.append( logit.squeeze(0) )
else:
loss_targets.append( token.long() )
loss_logits.append( logit )

loss_factors.append( loss_factor )
loss_names.append( name )
else:
@ -1159,6 +1169,8 @@ class Base_V2(nn.Module):
lens[1] = input.shape[0] + 1
elif name == "tone":
lens[1] += 2
elif name == "len":
lens[2] = 2
elif name == "resp":
lens[2] = input.shape[0]
@ -1179,8 +1191,8 @@ class Base_V2(nn.Module):
hidden_states = output.hidden_states

logits = self.audio_decoder( output.logits )
"""
# logits = self.audio_decoder( output.logits )

logits = [ logit for logit in output.logits ]
logits_aux = None
@ -1213,11 +1225,16 @@ class Base_V2(nn.Module):
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
"""

# Remove padding
logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]

for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level != "len":
continue

logits[batch_index] = logits[batch_index].view(-1, 5, 10)

if not training:
loss = None
stats = None