agnostified KD
parent 953d3eb030
commit 34a66e1052
@@ -280,9 +280,6 @@ class ModelExperimentalSettings:
     layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
     layerskip_e_scale: float = 0.2 # early-exit loss scalar value

-    teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
-    teacher_temperature: float = 1.0
-
 # I really need to clean this up
 @dataclass()
 class Model:

@@ -454,7 +451,7 @@ class LoRA:
         if not self.rvq_levels:
             return True
         return level in self.rvq_levels

 @dataclass()
 class Hyperparameters:
     batch_size: int = 8 # number of samples per training batch

@@ -476,6 +473,10 @@ class Hyperparameters:

     torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
     torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied

+    teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+    teacher_temperature: float = 1.0
+    teacher_loss_fn: str = "kl" # kl | mse
+
 @dataclass()
 class Evaluation:
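Note: the three fields added to Hyperparameters above feed the distillation objective applied later in the trainer — the existing hard losses are scaled by (1 - teacher_alpha) and a temperature-softened soft loss against the teacher is weighted by teacher_alpha * teacher_temperature². A minimal sketch of that mixing, with placeholder names (mix_kd_loss and its arguments are invented for illustration, not identifiers from this repo):

import torch
import torch.nn.functional as F

def mix_kd_loss(hard_loss, student_logits, teacher_logits, alpha=0.5, temperature=1.0):
    # soften both distributions with the same temperature T
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # KL(teacher || student), scaled by T**2 as in the standard distillation recipe
    soft_loss = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2
    # blend: (1 - alpha) keeps the usual hard loss, alpha weights the teacher signal
    return (1.0 - alpha) * hard_loss + alpha * soft_loss

The T² factor keeps the soft-loss gradients on roughly the same scale as the hard loss when the temperature changes, which is the usual justification in the distillation literature.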
@@ -48,8 +48,6 @@ class AR_NAR(Base):
         lang_list: list[Tensor] | None = None,
         tone_list: list[Tensor] | None = None,
         len_list: list[Tensor] | None = None,
-
-        teacher = None,
     ):
         # deduce batch_size
         if text_list is not None:

@@ -204,7 +202,6 @@ class AR_NAR(Base):
         return super().forward(
             inputs=inputs,
             quant_levels=quant_levels,
-            teacher=teacher,
         )

     def forward_nar_masked(

@@ -842,7 +839,6 @@ class AR_NAR(Base):
         len_list: list[Tensor] | None = None,

         training: bool | None = None,
-        teacher = None,

         disable_tqdm=False,
         use_lora=None,

@@ -879,8 +875,6 @@ class AR_NAR(Base):
             lang_list=lang_list,
             tone_list=tone_list,
             len_list=len_list,
-
-            teacher=teacher,
         )

         # is NAR
@@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap

 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
+Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])


@@ -442,8 +442,6 @@ class Base(nn.Module):
         unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
         interleave = self.config.experimental.interleave if self.config is not None else False
         noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
-        teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
-        teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5

         masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
         ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False

@@ -493,8 +491,6 @@ class Base(nn.Module):
         self.masking_ratio = masking_ratio
         self.ignore_inputs_for_loss = ignore_inputs_for_loss
         self.noncausal_masks = noncausal_masks
-        self.teacher_alpha = teacher_alpha
-        self.teacher_temperature = teacher_temperature

         # use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
         """

@@ -891,7 +887,7 @@ class Base(nn.Module):
             # but skip the last state, as it already is normalized
             hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]

-        return Logits(x, state, aux_loss, attentions, hidden_states, None)
+        return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)

     # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
     def inputs(

@@ -1456,8 +1452,6 @@ class Base(nn.Module):

         output_attentions: bool = False,
         output_hidden_states: bool = False,
-
-        teacher = None,
     ):
         # return early if it's "good" enough"
         # lambda because we need to capture the classifier_levels and mask

@@ -1503,7 +1497,7 @@ class Base(nn.Module):

         # derive quant levels from inputs if not provided
         if quant_levels is None:
-            quant_levels = self.get_input( inputs, "quant_level" )
+            quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]

         x_list = self.inputs_to_embeddings( inputs, quant_levels )


@@ -1615,66 +1609,6 @@ class Base(nn.Module):
         # to-do: instead make the cirriculum rely on samples processed instead of steps
         self.training_steps += 1 # batch_size

-        # get soft targets from teacher
-        # required to do it in here because the batch is further processed within the model (because of per-model config)
-        if teacher is not None:
-            # grab the teacher's logits
-            with torch.no_grad():
-                teacher_output = teacher.forward_super(
-                    inputs=inputs,
-                    quant_levels=quant_levels,
-                )
-
-            # determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
-            # we could recreate the target sequence with the ignore indices put in, but that's agony
-            output_lens = [ 0 for _ in range(batch_size) ]
-            for batch_index, batch in enumerate(inputs):
-                task_type = "tts"
-                for name, input in batch:
-                    if name == "task":
-                        task_type = input
-
-                for name, input in batch:
-                    if name == task_outputs.get(task_type, name):
-                        output_lens[batch_index] = input.shape[0]
-
-            # KD hyperparameters
-            T = self.teacher_temperature
-            A = self.teacher_alpha
-
-            # create probability distributions (literature says to have the students already log'd but not the teacher)
-            student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
-            teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
-
-            # filter out logits that are / would inf
-            # this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
-            for batch_index, output_len in enumerate( output_lens ):
-                mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
-                mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
-
-                mask = mask_a | mask_b
-                student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
-                teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
-
-            #soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
-            #soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
-            soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
-            soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
-
-            """
-            # flatten to a single sequence of token-probabilities
-            # but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
-            student_probs = torch.concat( student_probs, dim = 0 )
-            teacher_probs = torch.concat( teacher_probs, dim = 0 )
-            soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
-            """
-
-            # mix if not nan
-            if not torch.isnan(soft_loss).any():
-                for k in loss.keys():
-                    loss[k] *= (1.0 - A)
-                loss['kl'] = soft_loss * A
-
         # include any additional losses (for example: MoE router)
         if output.loss is not None:
             loss["aux_loss"] = output.loss

@@ -1683,7 +1617,7 @@ class Base(nn.Module):
             self.stats = stats

         # rewrap, because we're modifying the logits here
-        return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
+        return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)

     def sample(
         self,
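Note: threading inputs through the Logits namedtuple is what makes the distillation model-agnostic — the trainer can read the already-processed batch back out of the forward pass instead of the model accepting a teacher argument. A rough sketch of the access pattern (placeholder values, not the repo's actual call sites):

from collections import namedtuple

Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])

# every construction site now passes the extra field...
output = Logits(logits=[], state=None, inputs=[], loss=None, attentions=None, hidden_states=None, exited_layer=None)

# ...so callers such as train_feeder can recover the processed inputs by name
inputs = output.inputs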
@@ -31,7 +31,7 @@ def train_feeder(engine, batch, teacher=None):
     batch_size = len(batch["text"])
     engine.current_batch_size = batch_size

-    engine(
+    output = engine(
         text_list=batch["text"],
         proms_list=batch["proms"],
         resps_list=batch["resps"],

@@ -40,9 +40,75 @@ def train_feeder(engine, batch, teacher=None):
         task_list=batch["task"],

         training=True,
-        teacher=teacher,
     )

+    # get soft targets from teacher
+    if teacher is not None:
+        # extract inputs forwarded to model
+        inputs = output.inputs
+
+        # grab the teacher's logits
+        with torch.no_grad():
+            teacher_output = teacher.forward_super(
+                inputs=inputs,
+            )
+
+        # KD hyperparameters
+        T = cfg.hyperparameters.teacher_temperature
+        A = cfg.hyperparameters.teacher_alpha
+        L = cfg.hyperparameters.teacher_loss_fn
+
+        # I don't know what to call the last one
+        if L not in ["kl", "mse"]:
+            L = "kd"
+
+        # determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
+        # we could recreate the target sequence with the ignore indices put in, but that's agony
+        if not engine.module.ignore_inputs_for_loss:
+            student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
+            teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
+        else:
+            task_outputs = {
+                "tts": "resp",
+                "stt": "text",
+                "len": "len",
+            }
+            output_lens = [ 0 for _ in range(batch_size) ]
+            for batch_index, _batch in enumerate(inputs):
+                task_type = "tts"
+                for name, input in _batch:
+                    if name == "task":
+                        task_type = input
+
+                for name, input in _batch:
+                    if name == task_outputs.get(task_type, name):
+                        output_lens[batch_index] = input.shape[0]
+
+            # create probability distributions (literature says to have the students already log'd but not the teacher)
+            student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
+            teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
+
+        # filter out logits that are / would inf
+        # this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
+        for batch_index in range( batch_size ):
+            mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
+            mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
+
+            mask = mask_a | mask_b
+            student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
+            teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
+
+        if L == "kl":
+            soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
+        elif L == "mse":
+            soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
+        else:
+            soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
+
+        for k in engine.module.loss.keys():
+            engine.module.loss[k] *= (1.0 - A)
+        engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
+
     losses = engine.gather_attribute("loss")
     stat = engine.gather_attribute("stats")

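For reference, the per-sequence soft loss computed in the new train_feeder block boils down to the helper sketched below; the function name and signature are invented for illustration, and the task-dependent output-length trimming is skipped. The ordering matters for the KL case: F.kl_div expects log-probabilities as its first argument and plain probabilities as the target, which is why only the student side goes through log_softmax.

import torch
import torch.nn.functional as F

def soft_target_loss(student_logits, teacher_logits, temperature=1.0, loss_fn="kl"):
    # temperature-softened distributions; only the student side is kept in log-space
    student = F.log_softmax(student_logits / temperature, dim=-1)
    teacher = F.softmax(teacher_logits / temperature, dim=-1)

    # drop positions the student can never predict (log(0) = -inf) or the teacher
    # assigns zero mass to, mirroring the masking in the diff above
    mask = (student == -float("inf")) | (teacher == 0.0)
    student = torch.masked_select(student, ~mask)
    teacher = torch.masked_select(teacher, ~mask)

    if loss_fn == "kl":
        return F.kl_div(student, teacher, reduction="sum")
    if loss_fn == "mse":
        return F.mse_loss(student, teacher)
    # plain KD formulation: sum( p_teacher * (log p_teacher - log p_student) )
    return torch.sum(teacher * (teacher.log() - student))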
@@ -143,7 +143,7 @@ def train(
     # validate if there's at least one model to train
     found = False
     for name, engine in engines.items():
-        if engine.training:
+        if engine._training:
             found = True
             break
     if not found: