agnostified KD

This commit is contained in:
mrq 2024-12-06 23:53:46 -06:00
parent 953d3eb030
commit 34a66e1052
5 changed files with 78 additions and 83 deletions

View File

@ -280,9 +280,6 @@ class ModelExperimentalSettings:
layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
teacher_temperature: float = 1.0
# I really need to clean this up
@dataclass()
class Model:
@ -454,7 +451,7 @@ class LoRA:
if not self.rvq_levels:
return True
return level in self.rvq_levels
@dataclass()
class Hyperparameters:
batch_size: int = 8 # number of samples per training batch
@ -476,6 +473,10 @@ class Hyperparameters:
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
teacher_temperature: float = 1.0
teacher_loss_fn: str = "kl" # kl | mse
@dataclass()
class Evaluation:

View File

@ -48,8 +48,6 @@ class AR_NAR(Base):
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
teacher = None,
):
# deduce batch_size
if text_list is not None:
@ -204,7 +202,6 @@ class AR_NAR(Base):
return super().forward(
inputs=inputs,
quant_levels=quant_levels,
teacher=teacher,
)
def forward_nar_masked(
@ -842,7 +839,6 @@ class AR_NAR(Base):
len_list: list[Tensor] | None = None,
training: bool | None = None,
teacher = None,
disable_tqdm=False,
use_lora=None,
@ -879,8 +875,6 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
teacher=teacher,
)
# is NAR

View File

@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
@ -442,8 +442,6 @@ class Base(nn.Module):
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
interleave = self.config.experimental.interleave if self.config is not None else False
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@ -493,8 +491,6 @@ class Base(nn.Module):
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.teacher_alpha = teacher_alpha
self.teacher_temperature = teacher_temperature
# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
"""
@ -891,7 +887,7 @@ class Base(nn.Module):
# but skip the last state, as it already is normalized
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
return Logits(x, state, aux_loss, attentions, hidden_states, None)
return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
@ -1456,8 +1452,6 @@ class Base(nn.Module):
output_attentions: bool = False,
output_hidden_states: bool = False,
teacher = None,
):
# return early if it's "good" enough"
# lambda because we need to capture the classifier_levels and mask
@ -1503,7 +1497,7 @@ class Base(nn.Module):
# derive quant levels from inputs if not provided
if quant_levels is None:
quant_levels = self.get_input( inputs, "quant_level" )
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
x_list = self.inputs_to_embeddings( inputs, quant_levels )
@ -1615,66 +1609,6 @@ class Base(nn.Module):
# to-do: instead make the cirriculum rely on samples processed instead of steps
self.training_steps += 1 # batch_size
# get soft targets from teacher
# required to do it in here because the batch is further processed within the model (because of per-model config)
if teacher is not None:
# grab the teacher's logits
with torch.no_grad():
teacher_output = teacher.forward_super(
inputs=inputs,
quant_levels=quant_levels,
)
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
# we could recreate the target sequence with the ignore indices put in, but that's agony
output_lens = [ 0 for _ in range(batch_size) ]
for batch_index, batch in enumerate(inputs):
task_type = "tts"
for name, input in batch:
if name == "task":
task_type = input
for name, input in batch:
if name == task_outputs.get(task_type, name):
output_lens[batch_index] = input.shape[0]
# KD hyperparameters
T = self.teacher_temperature
A = self.teacher_alpha
# create probability distributions (literature says to have the students already log'd but not the teacher)
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
# filter out logits that are / would inf
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
for batch_index, output_len in enumerate( output_lens ):
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
mask = mask_a | mask_b
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
"""
# flatten to a single sequence of token-probabilities
# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
student_probs = torch.concat( student_probs, dim = 0 )
teacher_probs = torch.concat( teacher_probs, dim = 0 )
soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
"""
# mix if not nan
if not torch.isnan(soft_loss).any():
for k in loss.keys():
loss[k] *= (1.0 - A)
loss['kl'] = soft_loss * A
# include any additional losses (for example: MoE router)
if output.loss is not None:
loss["aux_loss"] = output.loss
@ -1683,7 +1617,7 @@ class Base(nn.Module):
self.stats = stats
# rewrap, because we're modifying the logits here
return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)
def sample(
self,

View File

@ -31,7 +31,7 @@ def train_feeder(engine, batch, teacher=None):
batch_size = len(batch["text"])
engine.current_batch_size = batch_size
engine(
output = engine(
text_list=batch["text"],
proms_list=batch["proms"],
resps_list=batch["resps"],
@ -40,9 +40,75 @@ def train_feeder(engine, batch, teacher=None):
task_list=batch["task"],
training=True,
teacher=teacher,
)
# get soft targets from teacher
if teacher is not None:
# extract inputs forwarded to model
inputs = output.inputs
# grab the teacher's logits
with torch.no_grad():
teacher_output = teacher.forward_super(
inputs=inputs,
)
# KD hyperparameters
T = cfg.hyperparameters.teacher_temperature
A = cfg.hyperparameters.teacher_alpha
L = cfg.hyperparameters.teacher_loss_fn
# I don't know what to call the last one
if L not in ["kl", "mse"]:
L = "kd"
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
# we could recreate the target sequence with the ignore indices put in, but that's agony
if not engine.module.ignore_inputs_for_loss:
student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
else:
task_outputs = {
"tts": "resp",
"stt": "text",
"len": "len",
}
output_lens = [ 0 for _ in range(batch_size) ]
for batch_index, _batch in enumerate(inputs):
task_type = "tts"
for name, input in _batch:
if name == "task":
task_type = input
for name, input in _batch:
if name == task_outputs.get(task_type, name):
output_lens[batch_index] = input.shape[0]
# create probability distributions (literature says to have the students already log'd but not the teacher)
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
# filter out logits that are / would inf
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
for batch_index in range( batch_size ):
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
mask = mask_a | mask_b
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
if L == "kl":
soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
elif L == "mse":
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
else:
soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
for k in engine.module.loss.keys():
engine.module.loss[k] *= (1.0 - A)
engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
losses = engine.gather_attribute("loss")
stat = engine.gather_attribute("stats")

View File

@ -143,7 +143,7 @@ def train(
# validate if there's at least one model to train
found = False
for name, engine in engines.items():
if engine.training:
if engine._training:
found = True
break
if not found: