agnostified KD
This commit is contained in:
parent
953d3eb030
commit
34a66e1052
|
@ -280,9 +280,6 @@ class ModelExperimentalSettings:
|
|||
layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
|
||||
layerskip_e_scale: float = 0.2 # early-exit loss scalar value
|
||||
|
||||
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
|
||||
teacher_temperature: float = 1.0
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
class Model:
|
||||
|
@ -454,7 +451,7 @@ class LoRA:
|
|||
if not self.rvq_levels:
|
||||
return True
|
||||
return level in self.rvq_levels
|
||||
|
||||
|
||||
@dataclass()
|
||||
class Hyperparameters:
|
||||
batch_size: int = 8 # number of samples per training batch
|
||||
|
@ -476,6 +473,10 @@ class Hyperparameters:
|
|||
|
||||
torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
|
||||
torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
|
||||
|
||||
teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
|
||||
teacher_temperature: float = 1.0
|
||||
teacher_loss_fn: str = "kl" # kl | mse
|
||||
|
||||
@dataclass()
|
||||
class Evaluation:
|
||||
|
|
|
@ -48,8 +48,6 @@ class AR_NAR(Base):
|
|||
lang_list: list[Tensor] | None = None,
|
||||
tone_list: list[Tensor] | None = None,
|
||||
len_list: list[Tensor] | None = None,
|
||||
|
||||
teacher = None,
|
||||
):
|
||||
# deduce batch_size
|
||||
if text_list is not None:
|
||||
|
@ -204,7 +202,6 @@ class AR_NAR(Base):
|
|||
return super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
teacher=teacher,
|
||||
)
|
||||
|
||||
def forward_nar_masked(
|
||||
|
@ -842,7 +839,6 @@ class AR_NAR(Base):
|
|||
len_list: list[Tensor] | None = None,
|
||||
|
||||
training: bool | None = None,
|
||||
teacher = None,
|
||||
|
||||
disable_tqdm=False,
|
||||
use_lora=None,
|
||||
|
@ -879,8 +875,6 @@ class AR_NAR(Base):
|
|||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=len_list,
|
||||
|
||||
teacher=teacher,
|
||||
)
|
||||
|
||||
# is NAR
|
||||
|
|
|
@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
|
|||
from ..data import get_task_symmap
|
||||
|
||||
# these seem more elegant than a dict
|
||||
Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
|
||||
Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
|
||||
Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
|
||||
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
||||
|
||||
|
@ -442,8 +442,6 @@ class Base(nn.Module):
|
|||
unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
|
||||
interleave = self.config.experimental.interleave if self.config is not None else False
|
||||
noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
|
||||
teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
|
||||
teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
|
||||
|
||||
masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
|
||||
ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
|
||||
|
@ -493,8 +491,6 @@ class Base(nn.Module):
|
|||
self.masking_ratio = masking_ratio
|
||||
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||
self.noncausal_masks = noncausal_masks
|
||||
self.teacher_alpha = teacher_alpha
|
||||
self.teacher_temperature = teacher_temperature
|
||||
|
||||
# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
|
||||
"""
|
||||
|
@ -891,7 +887,7 @@ class Base(nn.Module):
|
|||
# but skip the last state, as it already is normalized
|
||||
hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
|
||||
|
||||
return Logits(x, state, aux_loss, attentions, hidden_states, None)
|
||||
return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
|
||||
|
||||
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
|
||||
def inputs(
|
||||
|
@ -1456,8 +1452,6 @@ class Base(nn.Module):
|
|||
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
|
||||
teacher = None,
|
||||
):
|
||||
# return early if it's "good" enough"
|
||||
# lambda because we need to capture the classifier_levels and mask
|
||||
|
@ -1503,7 +1497,7 @@ class Base(nn.Module):
|
|||
|
||||
# derive quant levels from inputs if not provided
|
||||
if quant_levels is None:
|
||||
quant_levels = self.get_input( inputs, "quant_level" )
|
||||
quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
|
||||
|
||||
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
||||
|
||||
|
@ -1615,66 +1609,6 @@ class Base(nn.Module):
|
|||
# to-do: instead make the cirriculum rely on samples processed instead of steps
|
||||
self.training_steps += 1 # batch_size
|
||||
|
||||
# get soft targets from teacher
|
||||
# required to do it in here because the batch is further processed within the model (because of per-model config)
|
||||
if teacher is not None:
|
||||
# grab the teacher's logits
|
||||
with torch.no_grad():
|
||||
teacher_output = teacher.forward_super(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
|
||||
# we could recreate the target sequence with the ignore indices put in, but that's agony
|
||||
output_lens = [ 0 for _ in range(batch_size) ]
|
||||
for batch_index, batch in enumerate(inputs):
|
||||
task_type = "tts"
|
||||
for name, input in batch:
|
||||
if name == "task":
|
||||
task_type = input
|
||||
|
||||
for name, input in batch:
|
||||
if name == task_outputs.get(task_type, name):
|
||||
output_lens[batch_index] = input.shape[0]
|
||||
|
||||
# KD hyperparameters
|
||||
T = self.teacher_temperature
|
||||
A = self.teacher_alpha
|
||||
|
||||
# create probability distributions (literature says to have the students already log'd but not the teacher)
|
||||
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
|
||||
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
|
||||
|
||||
# filter out logits that are / would inf
|
||||
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
|
||||
for batch_index, output_len in enumerate( output_lens ):
|
||||
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
|
||||
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
|
||||
|
||||
mask = mask_a | mask_b
|
||||
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
|
||||
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
|
||||
|
||||
#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
|
||||
|
||||
"""
|
||||
# flatten to a single sequence of token-probabilities
|
||||
# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
|
||||
student_probs = torch.concat( student_probs, dim = 0 )
|
||||
teacher_probs = torch.concat( teacher_probs, dim = 0 )
|
||||
soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
|
||||
"""
|
||||
|
||||
# mix if not nan
|
||||
if not torch.isnan(soft_loss).any():
|
||||
for k in loss.keys():
|
||||
loss[k] *= (1.0 - A)
|
||||
loss['kl'] = soft_loss * A
|
||||
|
||||
# include any additional losses (for example: MoE router)
|
||||
if output.loss is not None:
|
||||
loss["aux_loss"] = output.loss
|
||||
|
@ -1683,7 +1617,7 @@ class Base(nn.Module):
|
|||
self.stats = stats
|
||||
|
||||
# rewrap, because we're modifying the logits here
|
||||
return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
|
||||
return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
|
|
|
@ -31,7 +31,7 @@ def train_feeder(engine, batch, teacher=None):
|
|||
batch_size = len(batch["text"])
|
||||
engine.current_batch_size = batch_size
|
||||
|
||||
engine(
|
||||
output = engine(
|
||||
text_list=batch["text"],
|
||||
proms_list=batch["proms"],
|
||||
resps_list=batch["resps"],
|
||||
|
@ -40,9 +40,75 @@ def train_feeder(engine, batch, teacher=None):
|
|||
task_list=batch["task"],
|
||||
|
||||
training=True,
|
||||
teacher=teacher,
|
||||
)
|
||||
|
||||
# get soft targets from teacher
|
||||
if teacher is not None:
|
||||
# extract inputs forwarded to model
|
||||
inputs = output.inputs
|
||||
|
||||
# grab the teacher's logits
|
||||
with torch.no_grad():
|
||||
teacher_output = teacher.forward_super(
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
# KD hyperparameters
|
||||
T = cfg.hyperparameters.teacher_temperature
|
||||
A = cfg.hyperparameters.teacher_alpha
|
||||
L = cfg.hyperparameters.teacher_loss_fn
|
||||
|
||||
# I don't know what to call the last one
|
||||
if L not in ["kl", "mse"]:
|
||||
L = "kd"
|
||||
|
||||
# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
|
||||
# we could recreate the target sequence with the ignore indices put in, but that's agony
|
||||
if not engine.module.ignore_inputs_for_loss:
|
||||
student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
|
||||
teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
|
||||
else:
|
||||
task_outputs = {
|
||||
"tts": "resp",
|
||||
"stt": "text",
|
||||
"len": "len",
|
||||
}
|
||||
output_lens = [ 0 for _ in range(batch_size) ]
|
||||
for batch_index, _batch in enumerate(inputs):
|
||||
task_type = "tts"
|
||||
for name, input in _batch:
|
||||
if name == "task":
|
||||
task_type = input
|
||||
|
||||
for name, input in _batch:
|
||||
if name == task_outputs.get(task_type, name):
|
||||
output_lens[batch_index] = input.shape[0]
|
||||
|
||||
# create probability distributions (literature says to have the students already log'd but not the teacher)
|
||||
student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
|
||||
teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
|
||||
|
||||
# filter out logits that are / would inf
|
||||
# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
|
||||
for batch_index in range( batch_size ):
|
||||
mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
|
||||
mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
|
||||
|
||||
mask = mask_a | mask_b
|
||||
student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
|
||||
teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
|
||||
|
||||
if L == "kl":
|
||||
soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
elif L == "mse":
|
||||
soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
else:
|
||||
soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
|
||||
|
||||
for k in engine.module.loss.keys():
|
||||
engine.module.loss[k] *= (1.0 - A)
|
||||
engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
|
||||
|
||||
losses = engine.gather_attribute("loss")
|
||||
stat = engine.gather_attribute("stats")
|
||||
|
||||
|
|
|
@ -143,7 +143,7 @@ def train(
|
|||
# validate if there's at least one model to train
|
||||
found = False
|
||||
for name, engine in engines.items():
|
||||
if engine.training:
|
||||
if engine._training:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
|
|
Loading…
Reference in New Issue
Block a user