From 61ed6628567e4d6c3366ac53938d0f4a7d4ff343 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 7 Dec 2024 12:31:54 -0600
Subject: [PATCH] ACTUALLY actually fix KD-loss (the -inf in the logits was
 caused by cringecode)

---
 vall_e/engines/__init__.py | 10 ++++++++++
 vall_e/models/base.py      | 23 ++++++++++++++++++++---
 vall_e/train.py            | 39 +++++++--------------------------------
 3 files changed, 37 insertions(+), 35 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 3685035..ba62a32 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -205,6 +205,16 @@ def load_engines(training=True, **model_kwargs):
 				("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
 				("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ),
 			]
+
+			# correcting an oversight
+			if model.config.experimental.split_classifiers and "len" in model.capabilities:
+				len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
+				keys.append((f"classifiers.proj.{len_idx}.weight", 11))
+				keys.append((f"classifiers.proj.{len_idx}.bias", 11))
+
+				keys.append((f"classifiers.proj.{nar_0_idx}.weight", 1024))
+				keys.append((f"classifiers.proj.{nar_0_idx}.bias", 1024))
+
 			for k, tokens in keys:
 				if k not in state:
 					continue
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0b983b6..c83e383 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -265,7 +265,7 @@ class Classifiers(nn.Module):
 			return names
 		return [ self.names.index(name) for name in names ]
 
-	def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
+	def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor:
 		dtype = xi.dtype
 		device = xi.device
 
@@ -278,8 +278,12 @@ class Classifiers(nn.Module):
 			levels = [ self.names.index(name) for name in names ]
 
 		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
+		if not stack:
+			return xi
+
 		# pad if needed
 		# to-do: validate that this causes ZERO issues
+		# addendum: this does cause problems
 		max_size = max([ x.shape[-1] for x in xi ])
 		xi = [
 			#x if l == 0 else
@@ -460,15 +464,18 @@ class Base(nn.Module):
 			n_resp_tokens = n_audio_tokens + 1
 			l_tokens = [n_resp_tokens] * self.n_resp_levels
 			resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
+			classifier_l_tokens = [n_resp_tokens] * self.n_resp_levels
 		# NAR-len model
 		elif "len" in self.capabilities:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			if "ar" in self.capabilities:
 				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
 				resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
 			else:
 				l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+				classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 				resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
 		# AR+NAR model
 		else:
@@ -476,12 +483,13 @@ class Base(nn.Module):
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+			classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 
-		classifier_l_tokens = l_tokens + [ n_text_tokens ]
+		classifier_l_tokens += [ n_text_tokens ]
 		classifier_l_names = resp_l_names + [ "stt" ]
 
 		if "len" in self.capabilities:
-			classifier_l_tokens += [ n_text_tokens ]
+			classifier_l_tokens += [ 11 ]
 			classifier_l_names += ["len"]
 
 		self.unified_position_ids = unified_position_ids
@@ -1577,6 +1585,15 @@ class Base(nn.Module):
 		if hidden_states is not None:
 			for i, state in enumerate( hidden_states ):
 				hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+
+		# corrections
+		"""
+		for batch_index, classifier_level in enumerate( classifier_levels ):
+			if classifier_level == "len" and logits[batch_index].shape[1] > 11:
+				logits[batch_index] = logits[batch_index][:,:11]
+			elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024:
+				logits[batch_index] = logits[batch_index][:,:1024]
+		"""
 
 		# compute loss if the target is given
 		if not training:
diff --git a/vall_e/train.py b/vall_e/train.py
index 0c2800a..5c3a221 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -58,16 +58,12 @@ def train_feeder(engine, batch, teacher=None):
 		A = cfg.hyperparameters.teacher_alpha
 		L = cfg.hyperparameters.teacher_loss_fn
 
-		# I don't know what to call the last one
-		if L not in ["kl", "mse"]:
-			L = "kl"
-
 		# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
 		# we could recreate the target sequence with the ignore indices put in, but that's agony
 		student_logits = [ logit / T for logit in output.logits ]
 		teacher_logits = [ logit / T for logit in teacher_output.logits ]
 
-		if engine.module.ignore_inputs_for_loss or True:
+		if engine.module.ignore_inputs_for_loss:
 			task_outputs = {
 				"tts": "resp",
 				"stt": "text",
@@ -88,38 +84,17 @@ def train_feeder(engine, batch, teacher=None):
 			student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
 			teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]
 
-		# split even further because tensor shapes may change
-		# losses can be done per-token anyways so it's fine
-		student_logits = [ [ l for l in logit ] for logit in student_logits ]
-		teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
-
-		# filter out logits that are / would inf
-		# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
-		for batch_index in range( batch_size ):
-			for token_index in range( len( student_logits[batch_index] ) ):
-				filter = -float("inf") # for some unknown reason -inf is poisoning the logits
-				mask_a = student_logits[batch_index][token_index] == filter
-				mask_b = teacher_logits[batch_index][token_index] == filter
-				# remove them from both distributions to keep things synced
-				mask = mask_a | mask_b
-
-				student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
-				teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
-
-		# kl-divergence operates on probability distributions
-		# teacher doesn't need to be in logspace but it makes things easier to do so and just pass log_target=True
 		if L == "kl":
-			student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
-			teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
+			student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
+			teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
 
-			soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
-		# mse shouldn't operate on probability distributions
-		elif L == "mse":
-			soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
+			soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
+		elif L == "mse":
+			soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]
 
 		for k in engine.module.loss.keys():
 			engine.module.loss[k] *= (1.0 - A)
-		engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
+		engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
 
 	losses = engine.gather_attribute("loss")
 	stat = engine.gather_attribute("stats")
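
What the patched train_feeder computes is plain temperature-scaled knowledge distillation: student and teacher logits are divided by a temperature T, converted to log-probabilities, and compared with a per-sequence KL divergence rescaled by T ** 2. Below is a minimal self-contained sketch of that objective, assuming per-sequence logit tensors of shape [seq_len, vocab]; the function name kd_loss and its standalone signature are illustrative, not the repo's API:

import torch
import torch.nn.functional as F

def kd_loss( student_logits: list[torch.Tensor], teacher_logits: list[torch.Tensor], T: float = 1.0 ) -> torch.Tensor:
	# soften both distributions with the temperature, in log-space for kl_div
	student = [ F.log_softmax( logit / T, dim=-1 ) for logit in student_logits ]
	teacher = [ F.log_softmax( logit / T, dim=-1 ) for logit in teacher_logits ]
	# KL(teacher || student) summed over each [seq_len, vocab] sequence;
	# log_target=True because the teacher is also passed as log-probabilities
	losses = [ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]
	# T ** 2 keeps gradient magnitudes roughly temperature-invariant; average over the batch
	return torch.stack( losses ).sum() * (T ** 2) / len( losses )

The blending of hard and soft losses stays in train_feeder itself: the existing losses are scaled by 1.0 - A and the distillation term by A.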
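The KL path only works when student and teacher agree on the vocabulary width at every position, which is what the removed masked_select workaround was fighting: Classifiers.forward padded every head's output up to the widest head with -inf so the per-level logits could be stacked. log_softmax keeps those entries at -inf, and F.kl_div with log_target=True evaluates exp(target) * (target - input), which hits 0 * (-inf - -inf) = NaN on the padded columns. A tiny repro of the poisoning, with hypothetical values (a 3-class head padded to width 4):

import torch
import torch.nn.functional as F

student = F.log_softmax( torch.tensor([ 0.5, 1.0, -0.5, -float("inf") ]), dim=-1 )
teacher = F.log_softmax( torch.tensor([ 0.2, 0.8, 0.0, -float("inf") ]), dim=-1 )
print( F.kl_div( student, teacher, reduction='sum', log_target=True ) )  # tensor(nan)

Hence the rest of the patch: Classifiers.forward gains a stack=False default that returns the unpadded per-level outputs, the len head is sized to 11 classes and the NAR:0:0 head to 1024, and load_engines resizes those weights in old checkpoints, so no -inf padding is needed in the first place.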