QoL so I can stop having to manually inject different configs

This commit is contained in:
mrq 2025-03-06 14:48:14 -06:00
parent 0d809561c6
commit 5cd71ef238
4 changed files with 97 additions and 78 deletions

View File

@@ -25,6 +25,8 @@ Training is (not-so-obviously) not dependent on:
* for the old (`base.py`) implementation, further experimentation is required, but from what I remember the smaller models don't have speech emerge as fast, while the larger size models don't seem to benefit much.
* for the new (`base_v2.py`) implementation, it seems that model size doesn't affect quality at all, at least in the primary phase of getting it to speak.
* the "training progression" (how the loss/accuracy/grad norm curves look) are about the exact same between the "normal" (1024 dim, 12 layers, 12 heads) size and the "half" (512 dim, 12 layers, 8 heads) size, and presumably for the "double" size (1538 dim, 24 layers, 24 heads).
* the "half" size might actually be too small for it to have enough "capacity" to attend to all the speakers I'm training against.
* per E2/F5's paper, a size of 1024 dim, 4x FFN, 16 heads, 24 layers might be preferable?
* it *probably* is necessary to have a larger model to better adhere to the reference clip, but experimentation hasn't yet reached the point where this can be verified.
* the audio codec, to an extent
* for the old (`base.py`) implementation, EnCodec seems to be the only codec that works well for it, as DAC might require some hacks or patience for the higher RVQ levels to train, while `nvidia/audio-codec-44khz` presumably requires an exotic training curriculum.
@@ -40,7 +42,6 @@ A training paradigm that seems to work for me is to:
* this also benefits from the model training on a variety of durations, to avoid it overfitting to the last duration range trained against
* optionally, you can sample based on speaker instead to balance out the speakers trained against, but this isn't all that necessary (a sketch follows below)
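A minimal sketch of that optional speaker-balanced sampling (hypothetical helpers operating on items with a `speaker` field, not the repo's actual dataloader):

```python
import random
from collections import Counter

def speaker_balanced_weights( items ):
    # weight each utterance by the inverse of its speaker's frequency,
    # so heavily-represented speakers aren't over-sampled
    counts = Counter( item["speaker"] for item in items )
    return [ 1.0 / counts[ item["speaker"] ] for item in items ]

def sample_batch( items, batch_size=8 ):
    # hypothetical helper; draws a batch with speaker balancing
    return random.choices( items, weights=speaker_balanced_weights( items ), k=batch_size )
```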
Training under `float16` (+AMP) should be fairly simple, but it's practically required to use the `deepspeed` backend.
* This is because `deepspeed` will automatically wrap the optimizer to handle training under `float16` and does some extra magic for stability. The `local` backend does do loss scaling, but not the extra steps.
* Training under `bfloat16` does not have to worry about this, but personally I feel `bfloat16` training sessions lack a certain training trait that `float16` ones have; a sketch of the dynamic loss scaling idea follows below.
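As a rough sketch of the loss-scaling behavior `deepspeed` wraps for you (a simplified illustration of dynamic loss scaling, not DeepSpeed's actual code; the window/floor values mirror the defaults set later in this commit):

```python
class DynamicLossScaler:
    # simplified sketch of dynamic float16 loss scaling; illustration only
    def __init__( self, init_scale=2.0**16, scale_window=1000, min_scale=32768.0 ):
        self.scale = init_scale
        self.scale_window = scale_window # raise the scale after this many good steps
        self.min_scale = min_scale # floor, since too small a scale fries the model
        self.good_steps = 0

    def update( self, overflowed: bool ):
        if overflowed:
            # gradients hit inf/NaN: halve the scale (clamped to the floor) and skip the step
            self.scale = max( self.scale / 2.0, self.min_scale )
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps % self.scale_window == 0:
                self.scale *= 2.0 # enough consecutive good steps: try a larger scale
```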
@@ -72,7 +73,6 @@ The optimizer used *mostly* doesn't matter, as AdamW seems to get moving faster,
* `APOLLO` needs more testing, but seemed adequate in cursory tests
* `Muon` requires much more testing, but absolutely cannot be used for predicting tokens in place (NAR demasking), and requires `cfg.model.experimental.predict_causally=True` (see the example after this list)
* I honestly don't think it gives good enough results from cursory tests for this application
* `Adagrad` surprisingly seems to "fix" (for now) my problems with the loss / accuracy bouncing.
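For example, opting into Muon per the note above would pair the optimizer selection with that flag; only `cfg.model.experimental.predict_causally` is named above, and the optimizer knob's exact path here is a guess:

```python
cfg.hyperparameters.optimizer = "muon" # assumed knob name/path
cfg.model.experimental.predict_causally = True # required for Muon, per the note above
```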
## Try Me

View File

@@ -281,6 +281,12 @@ class ModelExperimentalSettings:
#
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
# "auto" will pick best for codec
# "decreasing" will do the RVQ strat (prioritize lower levels)
# "normal" will do the FSQ strat (prioritize midrange)
# "equal" or "none" will set do no leveling
# list of floats to manually set
# these technically should be as hyperparameters
# performs token dropout to compensate for errors
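To make those options concrete, here is what each named strategy resolves to for an 8-level codec, mirroring the resolution logic this commit adds to `base_v2.py`:

```python
n_resp_levels = 8

# "decreasing": the RVQ strategy, prioritize lower levels
decreasing = [ 1.0 / (i + 1) for i in range(n_resp_levels) ]
# -> [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.125] (rounded)

# "normal": the FSQ strategy, prioritize the midrange
normal = [ 0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5 ]

# "equal" / "none": no leveling
equal = [ 1.0 ] * n_resp_levels
```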
@@ -561,8 +567,9 @@ class DeepSpeed:
optimizer: bool = True # use DeepSpeed optimizer wrapper
amp: bool = False # use DeepSpeed's AMP (requires some other package installed apparently)
loss_scale_window: int = 100
min_loss_scale: float = 8192.0
loss_scale_window: int = 1000
min_loss_scale: float = 32768.0
loss_scale: float = 0.0
config: dict = field(default_factory=lambda: {}) # to pass through deepspeed config
@@ -614,9 +621,9 @@ class DeepSpeed:
"fp16": {
"enabled": cfg.trainer.weight_dtype.lower() == "float16",
"auto_cast": True, # ???
"loss_scale_window": self.loss_scale_window, # raise every 100 consecutive good steps
"min_loss_scale": self.min_loss_scale, # loss scale hitting 8K fries the model, 16K is fine but 32K is comfy
"loss_scale": 0.0 if cfg.trainer.scale_loss else 1.0,
"loss_scale_window": self.loss_scale_window,
"min_loss_scale": self.min_loss_scale,
"loss_scale": self.loss_scale if cfg.trainer.scale_loss else 1.0, # use defined loss scale (defaults to 0, which is dynamic) if requested, or 1.0 (none) if not
},
"bf16": {
"enabled": cfg.trainer.weight_dtype.lower() == "bfloat16",

View File

@@ -1036,7 +1036,7 @@ def example_usage():
if task == "stt":
prom = [ task ]
else:
task = "tts" if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
task = "tts" # if random.random() > 0.1 or "len" not in cfg.model.capabilities else "len"
texts.append( text )
proms.append( prom )

View File

@@ -300,7 +300,7 @@ class Base_V2(nn.Module):
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
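As a minimal sketch of what the `logit_normalization` knob refers to (the LogitNorm idea from the papers cited in the config; this mirrors the formula, not necessarily the exact in-repo code):

```python
import torch

def logit_normalize( logits: torch.Tensor, tau: float ) -> torch.Tensor:
    # LogitNorm (arXiv:2205.09310): divide logits by their L2 norm times a
    # temperature tau, bounding their magnitude to temper overconfidence
    norms = torch.norm( logits, p=2, dim=-1, keepdim=True ) + 1e-7
    return logits / ( norms * tau )
```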
@@ -309,6 +309,27 @@ class Base_V2(nn.Module):
n_langs = config.langs if config is not None else 2
n_tones = config.tones if config is not None else 1
if audio_level_loss_factors == "auto":
audio_level_loss_factors = "normal" if n_audio_tokens == 1000 else "decreasing"
if audio_level_loss_factors == "decreasing":
audio_level_loss_factors = [1.0 / (i + 1) for i in range(n_resp_levels)]
elif audio_level_loss_factors == "normal":
if n_resp_levels == 8:
audio_level_loss_factors = [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
else:
center = n_resp_levels // 2
audio_level_loss_factors = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)]
# to-do: proper curriculum
# prioritizes midrange, maybe good for epoch 0?
# [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
# deprioritizes midrange, good for epoch 1?
# [0.875, 0.75, 0.625, 0.5, 0.5, 0.625, 0.75, 0.875]
elif audio_level_loss_factors == "equal":
audio_level_loss_factors = [1.0 for _ in range(n_resp_levels)]
if attention_backend == "auto":
attention_backend = "sdpa"
@@ -320,18 +341,6 @@ class Base_V2(nn.Module):
if attention_backend not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
# to-do: deduce nemo better-er
if n_audio_tokens == 1000:
# assume the midrange contains important details
center = n_resp_levels // 2
audio_level_weights = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)]
# to-do: proper curriculum
# prioritizes midrange, maybe good for epoch 0?
# [0.5, 0.625, 0.75, 0.875, 0.875, 0.75, 0.625, 0.5]
# deprioritizes midrange, good for epoch 1?
# [0.875, 0.75, 0.625, 0.5, 0.5, 0.625, 0.75, 0.875]
self.training = training
self.teaching = False
self.config = config
@@ -380,7 +389,7 @@ class Base_V2(nn.Module):
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.audio_level_weights = audio_level_weights
self.audio_level_loss_factors = audio_level_loss_factors
self.logit_normalization = logit_normalization
self.sep = nn.Parameter(torch.randn(d_model))
@@ -391,6 +400,7 @@ class Base_V2(nn.Module):
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = ml.Embedding(11, d_model)
# to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something
self.audio_emb = None
self.proms_emb = None
@@ -867,7 +877,7 @@ class Base_V2(nn.Module):
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
level_weights = self.audio_level_weights
level_loss_factor = self.audio_level_loss_factors
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@@ -876,6 +886,7 @@ class Base_V2(nn.Module):
return input
k_lo, k_hi = 1, 20
def _calc_loss( logit, sequence, causal = True, level = None ):
# filter tokens that exceed the vocab size
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@@ -906,37 +917,38 @@ class Base_V2(nn.Module):
sequence = sequence.reshape(-1)
nll = None
acc_k1 = None
acc_k_lo = None
if compute_hard_loss:
reduction = 'mean' if not batched else 'none'
weight = level_weights[level] if level is not None and not batched else 1
weight = level_loss_factor[level] if level is not None and not batched else 1
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
# manually weigh each level
if batched:
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device)
if compute_acc:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = 1,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k1 = accuracy_metric( logit, sequence )
if logit.shape[0] >= k_lo:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = k_lo,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k_lo = accuracy_metric( logit, sequence )
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = min(logit.shape[0], 80),
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k80 = accuracy_metric( logit, sequence )
if logit.shape[0] >= k_hi:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
top_k = k_hi,
average="micro",
multidim_average="global",
ignore_index = -100
).to(logit.device)
acc_k_hi = accuracy_metric( logit, sequence )
return nll, acc_k1, acc_k80
return nll, acc_k_lo, acc_k_hi
for batch_index, batch in enumerate(inputs):
quant_level = quant_levels[batch_index]
@@ -1022,7 +1034,7 @@ class Base_V2(nn.Module):
continue
if logits[batch_index].dim() < 3:
nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][start:end], token.long(), causal )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][start:end], token.long(), causal )
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = quant_level
@@ -1035,31 +1047,31 @@ class Base_V2(nn.Module):
name = f'{name}[{level}]'
sequence = token if token.dim() <= 1 else token[:, level]
nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
else:
sequence = token.t()
nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
if nll is not None:
nll = nll.mean()
loss_key = f'{name}.nll'
acc_k1_key = f'{name}.acc[k=1]'
acc_k80_key = f'{name}.acc[k=80]'
acc_k_lo_key = f'{name}.acc[k={k_lo}]'
acc_k_hi_key = f'{name}.acc[k={k_hi}]'
if nll is not None:
if loss_key not in loss:
loss[loss_key] = []
loss[loss_key].append( nll * loss_factor )
if acc_k1 is not None:
if acc_k1_key not in stats:
stats[acc_k1_key] = []
stats[acc_k1_key].append( acc_k1 )
if acc_k_lo is not None:
if acc_k_lo_key not in stats:
stats[acc_k_lo_key] = []
stats[acc_k_lo_key].append( acc_k_lo )
if acc_k80 is not None:
if acc_k80_key not in stats:
stats[acc_k80_key] = []
stats[acc_k80_key].append( acc_k80 )
if acc_k_hi is not None:
if acc_k_hi_key not in stats:
stats[acc_k_hi_key] = []
stats[acc_k_hi_key].append( acc_k_hi )
# add to list
else:
target.append( token )
@@ -1069,7 +1081,7 @@ class Base_V2(nn.Module):
if not self.config.loss_factors:
if logits[batch_index].dim() < 3:
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index], sequence, causal )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index], sequence, causal )
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = 0
@@ -1080,45 +1092,45 @@ class Base_V2(nn.Module):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, acc_k1, acc_k80 = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
nll, acc_k_lo, acc_k_hi = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
else:
nlls = []
acc_k1s = []
acc_k80s = []
acc_k_los = []
acc_k_his = []
for level, logit in enumerate( logits[batch_index] ):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, acc_k1, acc_k80 = _calc_loss( logit, sequence, causal, level )
nll, acc_k_lo, acc_k_hi = _calc_loss( logit, sequence, causal, level )
if nll:
nlls.append( nll )
if acc_k1:
acc_k1s.append( acc_k1 )
if acc_k80:
acc_k80s.append( acc_k80 )
if acc_k_lo:
acc_k_los.append( acc_k_lo )
if acc_k_hi:
acc_k_his.append( acc_k_hi )
if nlls:
nll = sum(nlls) / len(nlls)
if acc_k1s:
acc_k1 = sum(acc_k1s) / len(acc_k1s)
if acc_k80s:
acc_k80 = sum(acc_k80s) / len(acc_k80s)
if acc_k_los:
acc_k_lo = sum(acc_k_los) / len(acc_k_los)
if acc_k_his:
acc_k_hi = sum(acc_k_his) / len(acc_k_his)
if nll is not None:
if 'nll' not in loss:
loss['nll'] = []
loss["nll"].append( nll )
if acc_k1 is not None:
if 'acc[k=1]' not in stats:
stats['acc[k=1]'] = []
stats["acc[k=1]"].append( acc_k1 )
if acc_k_lo is not None:
if f'acc[k={k_lo}]' not in stats:
stats[f'acc[k={k_lo}]'] = []
stats[f"acc[k={k_lo}]"].append( acc_k_lo )
if acc_k80 is not None:
if 'acc[k=80]' not in stats:
stats['acc[k=80]'] = []
stats["acc[k=80]"].append( acc_k80 )
if acc_k_hi is not None:
if f'acc[k={k_hi}]' not in stats:
stats[f'acc[k={k_hi}]'] = []
stats[f"acc[k={k_hi}]"].append( acc_k_hi )
# average
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }