ACTUALLY actually fix KD-loss (the -inf in the logits was caused by cringecode)
parent f97e8b0c7f
commit 61ed662856
@@ -205,6 +205,16 @@ def load_engines(training=True, **model_kwargs):
             ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
             ("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ),
         ]
+
+        # correcting an oversight
+        if model.config.experimental.split_classifiers and "len" in model.capabilities:
+            len_idx, nar_0_idx = model.classifiers.indices(["len", "NAR:0:0"])
+            keys.append((f"classifiers.proj.{len_idx}.weight", 11))
+            keys.append((f"classifiers.proj.{len_idx}.bias", 11))
+
+            keys.append((f"classifiers.proj.{nar_0_idx}.weight", 1024))
+            keys.append((f"classifiers.proj.{nar_0_idx}.bias", 1024))
+
         for k, tokens in keys:
             if k not in state:
                 continue
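The loop that consumes `keys` is only partially visible in this hunk, so as a reading aid, here is a minimal sketch (not the repository's code) of how a `(key, token-count)` list like the one above can be used to reconcile classifier-head shapes when loading a checkpoint. The `resize_weight` helper and the toy `state` dict are assumptions for illustration only.

```python
import torch

def resize_weight(tensor: torch.Tensor, tokens: int) -> torch.Tensor:
    """Trim or zero-pad dim 0 so a classifier head has exactly `tokens` rows."""
    if tensor.shape[0] > tokens:   # e.g. a head that still carries a stop token it no longer needs
        return tensor[:tokens]
    if tensor.shape[0] < tokens:   # grow the head, keeping the rows that already exist
        pad = torch.zeros(tokens - tensor.shape[0], *tensor.shape[1:], dtype=tensor.dtype)
        return torch.cat([tensor, pad], dim=0)
    return tensor

state = {"classifiers.proj.2.weight": torch.randn(1025, 1024)}  # toy checkpoint entry
keys = [("classifiers.proj.2.weight", 1024)]                    # e.g. the NAR:0:0 head: no stop token
for k, tokens in keys:
    if k not in state:
        continue
    state[k] = resize_weight(state[k], tokens)
```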
@@ -265,7 +265,7 @@ class Classifiers(nn.Module):
             return names
         return [ self.names.index(name) for name in names ]

-    def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
+    def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor:
         dtype = xi.dtype
         device = xi.device

@@ -278,8 +278,12 @@ class Classifiers(nn.Module):
             levels = [ self.names.index(name) for name in names ]

         xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
+        if not stack:
+            return xi
+
         # pad if needed
         # to-do: validate that this causes ZERO issues
+        # addendum: this does cause problems
         max_size = max([ x.shape[-1] for x in xi ])
         xi = [
             #x if l == 0 else
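The pad value itself sits outside this hunk, but the commit title and the filter removed in the last hunk both point at `-inf` padding of the shorter heads when the projections are stacked. A self-contained illustration (toy numbers, assuming a `-inf` pad) of why a single padded column turns the KL term into NaN while an ordinary hard-target cross-entropy on the same logits stays finite:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 3.0])
padded = torch.cat([logits, torch.full((2,), float("-inf"))])  # pad out to a wider head

s = F.log_softmax(padded, dim=-1)  # padded slots come out as -inf log-probs
t = F.log_softmax(padded, dim=-1)
print(F.kl_div(s, t, reduction="sum", log_target=True))         # tensor(nan): the -inf slots poison the pointwise terms
print(F.cross_entropy(padded.unsqueeze(0), torch.tensor([2])))  # finite: the target class keeps real probability mass
```

Returning the unstacked list when `stack=False` sidesteps the padding entirely, which is the actual fix this commit relies on.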
@@ -460,15 +464,18 @@ class Base(nn.Module):
             n_resp_tokens = n_audio_tokens + 1
             l_tokens = [n_resp_tokens] * self.n_resp_levels
             resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
+            classifier_l_tokens = [n_resp_tokens] * self.n_resp_levels
         # NAR-len model
         elif "len" in self.capabilities:
             # +1 to include the stop or mask token
             n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
             if "ar" in self.capabilities:
                 l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+                classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
                 resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
             else:
                 l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+                classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
                 resp_l_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
         # AR+NAR model
         else:
@@ -476,12 +483,13 @@ class Base(nn.Module):
             n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
             l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
             resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+            classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)

-        classifier_l_tokens = l_tokens + [ n_text_tokens ]
+        classifier_l_tokens += [ n_text_tokens ]
         classifier_l_names = resp_l_names + [ "stt" ]

         if "len" in self.capabilities:
-            classifier_l_tokens += [ n_text_tokens ]
+            classifier_l_tokens += [ 11 ]
             classifier_l_names += ["len"]

         self.unified_position_ids = unified_position_ids
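For concreteness, a worked example of the bookkeeping above under assumed settings (n_audio_tokens = 1024, n_resp_levels = 8, causal_size > 0, both "ar" and "len" enabled, n_text_tokens = 256 as an arbitrary stand-in); these numbers are illustrative, not read from the diff:

```python
# Mirrors the NAR-len + AR branch of the hunks above with concrete (assumed) values.
n_audio_tokens, n_resp_levels, n_text_tokens = 1024, 8, 256

n_resp_tokens = n_audio_tokens + 1  # +1 for the stop/mask token, per the diff comment
classifier_l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens - 1]
classifier_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(n_resp_levels - 1)] + ['NAR:0:0']

classifier_l_tokens += [n_text_tokens]  # "stt" head
classifier_l_names += ['stt']
classifier_l_tokens += [11]             # "len" head: 11 classes
classifier_l_names += ['len']

print(list(zip(classifier_l_names, classifier_l_tokens)))
# AR:0:0 -> 1025, NAR:i:i+1 -> 1024 (x7), NAR:0:0 -> 1024, stt -> 256, len -> 11
```

The final NAR:0:0 entry dropping to n_resp_tokens - 1 is what lines up with the 1024-wide head corrected in load_engines above, and with the disabled truncation guard in the next hunk.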
@@ -1577,6 +1585,15 @@ class Base(nn.Module):
         if hidden_states is not None:
             for i, state in enumerate( hidden_states ):
                 hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+
+        # corrections
+        """
+        for batch_index, classifier_level in enumerate( classifier_levels ):
+            if classifier_level == "len" and logits[batch_index].shape[1] > 11:
+                logits[batch_index] = logits[batch_index][:,:11]
+            elif classifier_level == "NAR:0:0" and logits[batch_index].shape[1] > 1024:
+                logits[batch_index] = logits[batch_index][:,:1024]
+        """

         # compute loss if the target is given
         if not training:
@@ -58,16 +58,12 @@ def train_feeder(engine, batch, teacher=None):
         A = cfg.hyperparameters.teacher_alpha
         L = cfg.hyperparameters.teacher_loss_fn

-        # I don't know what to call the last one
-        if L not in ["kl", "mse"]:
-            L = "kl"
-
         # determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
         # we could recreate the target sequence with the ignore indices put in, but that's agony
         student_logits = [ logit / T for logit in output.logits ]
         teacher_logits = [ logit / T for logit in teacher_output.logits ]

-        if engine.module.ignore_inputs_for_loss or True:
+        if engine.module.ignore_inputs_for_loss:
             task_outputs = {
                 "tts": "resp",
                 "stt": "text",
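As an aside on the `logit / T` scaling kept above: dividing by a temperature T > 1 flattens both distributions so the student also learns from the teacher's non-argmax mass, which is the usual reason for the later T**2 factor on the soft loss. Toy numbers, not from the repository:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 0.5])
print(F.softmax(logits, dim=-1))        # ~[0.86, 0.12, 0.03] -- sharp, near one-hot
print(F.softmax(logits / 2.0, dim=-1))  # ~[0.65, 0.24, 0.11] -- softened by T = 2
```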
@@ -88,38 +84,17 @@ def train_feeder(engine, batch, teacher=None):
             student_logits = [ logit[-l:] for logit, l in zip( student_logits, output_lens ) ]
             teacher_logits = [ logit[-l:] for logit, l in zip( teacher_logits, output_lens ) ]

-        # split even further because tensor shapes may change
-        # losses can be done per-token anyways so it's fine
-        student_logits = [ [ l for l in logit ] for logit in student_logits ]
-        teacher_logits = [ [ l for l in logit ] for logit in teacher_logits ]
-
-        # filter out logits that are / would inf
-        # this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
-        for batch_index in range( batch_size ):
-            for token_index in range( len( student_logits[batch_index] ) ):
-                filter = -float("inf") # for some unknown reason -inf is poisoning the logits
-                mask_a = student_logits[batch_index][token_index] == filter
-                mask_b = teacher_logits[batch_index][token_index] == filter
-                # remove them from both distributions to keep things synced
-                mask = mask_a | mask_b
-
-                student_logits[batch_index][token_index] = torch.masked_select( student_logits[batch_index][token_index], ~mask )
-                teacher_logits[batch_index][token_index] = torch.masked_select( teacher_logits[batch_index][token_index], ~mask )
-
         # kl-divergence operates on probability distributions
         # teacher doesn't need to be in logspace but it makes things easier to do so and just pass log_target=True
         if L == "kl":
-            student_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in student_logits ]
-            teacher_probs = [ [ F.log_softmax( l, dim=-1 ) for l in logit ] for logit in teacher_logits ]
+            student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
+            teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]

-            soft_losses = [ sum([ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student, teacher ) ]) / len(student) for student, teacher in zip( student_probs, teacher_probs ) ]
-        # mse shouldn't operate on probability distributions
-        elif L == "mse":
-            soft_losses = [ sum([ F.mse_loss( s, t ) for s, t in zip(student, teacher) ]) / len(student) for student, teacher in zip( student_logits, teacher_logits ) ]
+            soft_losses = [ F.kl_div( student, teacher, reduction='sum', log_target=True ) for student, teacher in zip( student_probs, teacher_probs ) ]
+        elif L == "mse":
+            soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_logits, teacher_logits ) ]

         for k in engine.module.loss.keys():
             engine.module.loss[k] *= (1.0 - A)
-        engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
+        engine.module.loss[L] = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size

         losses = engine.gather_attribute("loss")
         stat = engine.gather_attribute("stats")
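Putting the surviving pieces together, here is a condensed, self-contained sketch of the fixed KD path, with toy tensors in place of `output.logits` / `teacher_output.logits`. The per-item vocab widths are assumptions matching the heads discussed earlier (1025 for AR:0:0, 1024 for NAR:0:0, 11 for len); the point is that, after the classifier change, each batch item keeps its own width, so the logits stay a list and each KL term is reduced per sequence before the scalars are stacked. The existing hard losses are still scaled by `(1.0 - A)` as shown above.

```python
import torch
import torch.nn.functional as F

T, A, batch_size = 2.0, 0.5, 3
widths = [1025, 1024, 11]                 # one classifier level per batch item (assumed values)
student_logits = [ torch.randn(10, w) / T for w in widths ]
teacher_logits = [ torch.randn(10, w) / T for w in widths ]

student_probs = [ F.log_softmax( logit, dim=-1 ) for logit in student_logits ]
teacher_probs = [ F.log_softmax( logit, dim=-1 ) for logit in teacher_logits ]
soft_losses = [ F.kl_div( s, t, reduction='sum', log_target=True ) for s, t in zip( student_probs, teacher_probs ) ]

soft_loss = torch.stack(soft_losses).sum() * A * (T ** 2) / batch_size
print(torch.isfinite(soft_loss))          # tensor(True): no -inf padding, no NaN
```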