ACTUALLY actually fix KD-loss (the -inf in the logits was caused by cringecode)

This commit is contained in:
mrq 2024-12-07 12:31:54 -06:00
parent f97e8b0c7f
commit 61ed662856
3 changed files with 37 additions and 35 deletions

View File

@@ -205,6 +205,16 @@ def load_engines(training=True, **model_kwargs):
("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ),
]
# correcting an oversight: the "len" and "NAR:0:0" classifier heads were missing from the resize list
if model.config.experimental.split_classifiers and "len" in model.capabilities:
len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
keys.append((f"classifiers.proj.{len_idx}.weight", 11))
keys.append((f"classifiers.proj.{len_idx}.bias", 11))
keys.append((f"classifiers.proj.{nar_0_idx}.weight", 1024))
keys.append((f"classifiers.proj.{nar_0_idx}.bias", 1024))
for k, tokens in keys:
if k not in state:
continue
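For reference, a minimal sketch of what the resize pass does with these (key, token count) pairs: trim or zero-pad a head's rows in the checkpoint to the expected width. The resize_rows helper, the classifier index, and the example state dict below are hypothetical, not the repo's actual loader.

import torch

def resize_rows(param: torch.Tensor, tokens: int) -> torch.Tensor:
	# trim extra rows, or zero-pad missing ones, along the output (vocab) dimension
	if param.shape[0] > tokens:
		return param[:tokens]
	if param.shape[0] < tokens:
		pad = torch.zeros(tokens - param.shape[0], *param.shape[1:], dtype=param.dtype)
		return torch.cat([param, pad], dim=0)
	return param

# hypothetical usage: a NAR:0:0 head saved with 1025 rows gets trimmed to 1024
state = { "classifiers.proj.9.weight": torch.randn(1025, 1024), "classifiers.proj.9.bias": torch.randn(1025) }
for k, tokens in [ ("classifiers.proj.9.weight", 1024), ("classifiers.proj.9.bias", 1024) ]:
	if k in state:
		state[k] = resize_rows(state[k], tokens)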

View File

@@ -265,7 +265,7 @@ class Classifiers(nn.Module):
return names
return [ self.names.index(name) for name in names ]
def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor:
dtype = xi.dtype
device = xi.device
@@ -278,8 +278,12 @@ class Classifiers(nn.Module):
levels = [ self.names.index(name) for name in names ]
xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
if not stack:
return xi
# pad if needed
# to-do: validate that this causes ZERO issues
# addendum: this does cause problems
max_size = max([ x.shape[-1] for x in xi ])
xi = [
#x if l == 0 else
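A toy sketch (not the real Classifiers class) contrasting the two return modes; the -inf pad fill is an assumption based on the commit message, and the vocab sizes and module layout are illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyClassifiers(nn.Module):
	def __init__(self, d_model=16, vocab_sizes=(1025, 1024, 11)):
		super().__init__()
		self.proj = nn.ModuleList([ nn.Linear(d_model, v) for v in vocab_sizes ])

	def forward(self, xi, levels, stack=False):
		xi = [ self.proj[l](x) for x, l in zip(xi, levels) ]
		if not stack:
			return xi # ragged list: each head keeps its own width
		# pad every head up to the widest one so the batch can be stacked;
		# the padded columns never correspond to a real token, hence the -inf hazard
		max_size = max( x.shape[-1] for x in xi )
		xi = [ F.pad(x, (0, max_size - x.shape[-1]), value=-float("inf")) for x in xi ]
		return torch.stack(xi)

m = ToyClassifiers()
inputs = [ torch.randn(4, 16), torch.randn(4, 16) ]
print([ x.shape for x in m(inputs, levels=[0, 2]) ]) # [(4, 1025), (4, 11)]
print(m(inputs, levels=[0, 2], stack=True).shape)    # (2, 4, 1025): the len head picked up 1014 columns of -inf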
@@ -460,15 +464,18 @@ class Base(nn.Module):
n_resp_tokens = n_audio_tokens + 1
l_tokens = [n_resp_tokens] * self.n_resp_levels
resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
classifier_l_tokens = [n_resp_tokens] * self.n_resp_levels
# NAR-len model
elif "len" in self.capabilities:
# +1 to include the stop or mask token
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
if "ar" in self.capabilities:
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
else:
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
# AR+NAR model
else:
@@ -476,12 +483,13 @@ class Base(nn.Module):
n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
classifier_l_tokens = l_tokens + [ n_text_tokens ]
classifier_l_tokens += [ n_text_tokens ]
classifier_l_names = resp_l_names + [ "stt" ]
if "len" in self.capabilities:
classifier_l_tokens += [ n_text_tokens ]
classifier_l_tokens += [ 11 ]
classifier_l_names += ["len"]
self.unified_position_ids = unified_position_ids
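Plugging in some assumed config values (n_audio_tokens = 1024, n_resp_levels = 8, n_text_tokens = 256, causal_size = 1) makes the head sizes for the "len" + "ar" branch concrete: the NAR:0:0 classifier now ends at 1024 outputs and the len head at 11, matching the resize keys added to load_engines above.

n_audio_tokens, n_resp_levels, n_text_tokens, causal_size = 1024, 8, 256, 1

n_resp_tokens = n_audio_tokens + ( 1 if causal_size > 0 else 0 ) # 1025: audio codes + stop/mask token
# l_tokens keeps the extra token for NAR:0:0 ...
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens]
# ... but its classifier head now drops it (1024 outputs, not 1025)
classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens - 1]
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( n_resp_levels - 1 )] + ['NAR:0:0']

classifier_l_tokens += [ n_text_tokens ] # "stt" head decodes text tokens
classifier_l_names = resp_l_names + [ "stt" ]
classifier_l_tokens += [ 11 ]            # "len" head: presumably digits 0-9 plus a stop token
classifier_l_names += [ "len" ]

print(dict(zip(classifier_l_names, classifier_l_tokens)))
# {'AR:0:0': 1025, 'NAR:0:1': 1024, ..., 'NAR:6:7': 1024, 'NAR:0:0': 1024, 'stt': 256, 'len': 11}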
@@ -1577,6 +1585,15 @@ class Base(nn.Module):
if hidden_states is not None:
for i, state in enumerate( hidden_states ):
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
# corrections
"""
for batch_index, classifier_level in enumerate( classifier_levels ):
if classifier_level == "len" and logits[batch_index].shape[1] > 11:
logits[batch_index] = logits[batch_index][:,:11]
elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024:
logits[batch_index] = logits[batch_index][:,:1024]
"""
# compute loss if the target is given
if not training:
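The disabled corrections above can also be written as a standalone clamp, sketched below; the level-to-width table just mirrors the 11 / 1024 numbers from this commit, and clamp_logits is a hypothetical name, not something that exists in the repo.

import torch

EXPECTED_WIDTH = { "len": 11, "NAR:0:0": 1024 }

def clamp_logits(logits, classifier_levels):
	out = []
	for logit, level in zip(logits, classifier_levels):
		width = EXPECTED_WIDTH.get(level)
		if width is not None and logit.shape[-1] > width:
			logit = logit[..., :width] # drop the stray columns instead of letting them reach the loss
		out.append(logit)
	return out

logits = [ torch.randn(6, 1025), torch.randn(4, 11) ]
print([ l.shape for l in clamp_logits(logits, ["NAR:0:0", "len"]) ]) # [(6, 1024), (4, 11)]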

View File

@@ -58,16 +58,12 @@ def train_feeder(engine, batch, teacher=None):
A = cfg.hyperparameters.teacher_alpha
L = cfg.hyperparameters.teacher_loss_fn
# I don't know what to call the last one
if L not in ["kl", "mse"]:
L = "kl"
# determine the output length for each batch (some input embeddings never map back to a discrete token anyway)
# we could recreate the target sequence with the ignore indices put in, but that's agony
student_logits = [ logit / T for logit in output.logits ]
teacher_logits = [ logit / T for logit in teacher_output.logits ]
if engine.module.ignore_inputs_for_loss or True:
if engine.module.ignore_inputs_for_loss:
task_outputs = {
"tts": "resp",
"stt": "text",
@@ -88,38 +84,17 @@
student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]
# split even further because tensor shapes may change
# losses can be done per-token anyways so it's fine
student_logits = [ [ l for l in logit ] for logit in student_logits ]
teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
# filter out logits that are / would be -inf
# this causes problems when computing the loss if there are any inherently impossible probabilities (for example, the stop token during NAR RVQ-0 demasking, since I did not clip it from the classifier)
for batch_index in range( batch_size ):
for token_index in range( len( student_logits[batch_index] ) ):
filter = -float("inf") # for some unknown reason -inf is poisoning the logits
mask_a = student_logits[batch_index][token_index] == filter
mask_b = teacher_logits[batch_index][token_index] == filter
# remove them from both distributions to keep things synced
mask = mask_a | mask_b
student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
# kl-divergence operates on probability distributions
# the teacher doesn't need to be in log-space, but keeping both in log-space and passing log_target=True makes things easier
if L == "kl":
student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
# mse shouldn't operate on probability distributions
elif L == "mse":
soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
elif L == "mse":
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]
for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A)
engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")
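Putting the surviving pieces together, a self-contained sketch of the soft-loss math this commit lands on, with made-up tensors and placeholder values standing in for cfg.hyperparameters.teacher_temperature / teacher_alpha / teacher_loss_fn and for the engine's tracked losses.

import torch
import torch.nn.functional as F

T, A, loss_fn, batch_size = 2.0, 0.5, "kl", 2
hard_losses = { "nll": torch.tensor(3.1) } # stand-in for engine.module.loss

# per-sample [seq, vocab] logits, already temperature-scaled and trimmed to the output span
student_logits = [ torch.randn(4, 1024) / T, torch.randn(6, 1024) / T ]
teacher_logits = [ torch.randn(4, 1024) / T, torch.randn(6, 1024) / T ]

if loss_fn == "kl":
	# both sides in log-space; log_target=True tells kl_div the target is already log-probs
	student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
	teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
	soft_losses = [ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student_probs, teacher_probs ) ]
else: # "mse" operates on the scaled logits directly, not on probability distributions
	soft_losses = [ F.mse_loss( s, t ) for s, t in zip( student_logits, teacher_logits ) ]

# blend: down-weight the hard losses by (1 - A), add the soft loss scaled by A * T^2
for k in hard_losses:
	hard_losses[k] = hard_losses[k] * (1.0 - A)
hard_losses[loss_fn] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
print(hard_losses)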