diff --git a/vall_e/config.py b/vall_e/config.py
index f972082..f520dde 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -280,9 +280,6 @@ class ModelExperimentalSettings:
 	layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
 	layerskip_e_scale: float = 0.2 # early-exit loss scalar value
 
-	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
-	teacher_temperature: float = 1.0
-
 # I really need to clean this up
 @dataclass()
 class Model:
@@ -454,7 +451,7 @@ class LoRA:
 		if not self.rvq_levels:
 			return True
 		return level in self.rvq_levels
-	
+
 @dataclass()
 class Hyperparameters:
 	batch_size: int = 8 # number of samples per training batch
@@ -476,6 +473,10 @@ class Hyperparameters:
 	torch_optimizer: bool = False # if the requested optimizer is torch-derived rather than deepspeed supplied
 	torch_scheduler: bool = False # if the requested scheduler is torch-derived rather than deepspeed-supplied
 
+
+	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+	teacher_temperature: float = 1.0
+	teacher_loss_fn: str = "kl" # kl | mse
 
 @dataclass()
 class Evaluation:
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 21945b2..a771e81 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -48,8 +48,6 @@ class AR_NAR(Base):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
-
-		teacher = None,
 	):
 		# deduce batch_size
 		if text_list is not None:
@@ -204,7 +202,6 @@ class AR_NAR(Base):
 		return super().forward(
 			inputs=inputs,
 			quant_levels=quant_levels,
-			teacher=teacher,
 		)
 
 	def forward_nar_masked(
@@ -842,7 +839,6 @@ class AR_NAR(Base):
 		len_list: list[Tensor] | None = None,
 
 		training: bool | None = None,
-		teacher = None,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -879,8 +875,6 @@ class AR_NAR(Base):
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-
-				teacher=teacher,
 			)
 
 		# is NAR
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index cded4ec..0b983b6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap
 
 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
+Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
@@ -442,8 +442,6 @@ class Base(nn.Module):
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
-		teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
-		teacher_temperature = self.config.experimental.teacher_temperature if self.config is not None else 0.5
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
 
@@ -493,8 +491,6 @@ class Base(nn.Module):
 		self.masking_ratio = masking_ratio
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
-		self.teacher_alpha = teacher_alpha
-		self.teacher_temperature = teacher_temperature
 
 		# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
 		"""
@@ -891,7 +887,7 @@ class Base(nn.Module):
 			# but skip the last state, as it already is normalized
 			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ]
 
-		return Logits(x, state, aux_loss, attentions, hidden_states, None)
+		return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None)
 
 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
@@ -1456,8 +1452,6 @@ class Base(nn.Module):
 
 		output_attentions: bool = False,
 		output_hidden_states: bool = False,
-
-		teacher = None,
 	):
 		# return early if it's "good" enough"
 		# lambda because we need to capture the classifier_levels and mask
@@ -1503,7 +1497,7 @@ class Base(nn.Module):
 
 		# derive quant levels from inputs if not provided
 		if quant_levels is None:
-			quant_levels = self.get_input( inputs, "quant_level" )
+			quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
 
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 
@@ -1615,66 +1609,6 @@ class Base(nn.Module):
 			# to-do: instead make the cirriculum rely on samples processed instead of steps
 			self.training_steps += 1 # batch_size
 
-			# get soft targets from teacher
-			# required to do it in here because the batch is further processed within the model (because of per-model config)
-			if teacher is not None:
-				# grab the teacher's logits
-				with torch.no_grad():
-					teacher_output = teacher.forward_super(
-						inputs=inputs,
-						quant_levels=quant_levels,
-					)
-
-				# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
-				# we could recreate the target sequence with the ignore indices put in, but that's agony
-				output_lens = [ 0 for _ in range(batch_size) ]
-				for batch_index, batch in enumerate(inputs):
-					task_type = "tts"
-					for name, input in batch:
-						if name == "task":
-							task_type = input
-
-					for name, input in batch:
-						if name == task_outputs.get(task_type, name):
-							output_lens[batch_index] = input.shape[0]
-
-				# KD hyperparameters
-				T = self.teacher_temperature
-				A = self.teacher_alpha
-
-				# create probability distributions (literature says to have the students already log'd but not the teacher)
-				student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( logits, output_lens ) ]
-				teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
-
-				# filter out logits that are / would inf
-				# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
-				for batch_index, output_len in enumerate( output_lens ):
-					mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
-					mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
-
-					mask = mask_a | mask_b
-					student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
-					teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
-
-				#soft_losses = [ F.kl_div( student, teacher, reduction='mean' ) for student, teacher in zip( student_probs, teacher_probs ) ]
-				#soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
-				soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
-				soft_loss = torch.stack([*soft_losses]).sum() * (T ** 2) / batch_size
-
-				"""
-				# flatten to a single sequence of token-probabilities
-				# but this shouldn't actually work because some logits might be (..., 1024) and some might be (..., 1025)
-				student_probs = torch.concat( student_probs, dim = 0 )
-				teacher_probs = torch.concat( teacher_probs, dim = 0 )
-				soft_loss = F.mse_loss( student_probs, teacher_probs ) * (T ** 2) / batch_size
-				"""
-
-				# mix if not nan
-				if not torch.isnan(soft_loss).any():
-					for k in loss.keys():
-						loss[k] *= (1.0 - A)
-					loss['kl'] = soft_loss * A
-
 		# include any additional losses (for example: MoE router)
 		if output.loss is not None:
 			loss["aux_loss"] = output.loss
@@ -1683,7 +1617,7 @@ class Base(nn.Module):
 			self.stats = stats
 
 		# rewrap, because we're modifying the logits here
-		return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
+		return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer)
 
 	def sample(
 		self,
diff --git a/vall_e/train.py b/vall_e/train.py
index 1c13ff8..03ba1f9 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -31,7 +31,7 @@ def train_feeder(engine, batch, teacher=None):
 		batch_size = len(batch["text"])
 		engine.current_batch_size = batch_size
 
-		engine(
+		output = engine(
 			text_list=batch["text"],
 			proms_list=batch["proms"],
 			resps_list=batch["resps"],
@@ -40,9 +40,75 @@ def train_feeder(engine, batch, teacher=None):
 			task_list=batch["task"],
 
 			training=True,
-			teacher=teacher,
 		)
 
+		# get soft targets from teacher
+		if teacher is not None:
+			# extract inputs forwarded to model
+			inputs = output.inputs
+
+			# grab the teacher's logits
+			with torch.no_grad():
+				teacher_output = teacher.forward_super(
+					inputs=inputs,
+				)
+
+			# KD hyperparameters
+			T = cfg.hyperparameters.teacher_temperature
+			A = cfg.hyperparameters.teacher_alpha
+			L = cfg.hyperparameters.teacher_loss_fn
+
+			# I don't know what to call the last one
+			if L not in ["kl", "mse"]:
+				L = "kd"
+
+			# determine the output length for each batch (because blah blah some embeddings don't map to a discrete token anyways)
+			# we could recreate the target sequence with the ignore indices put in, but that's agony
+			if not engine.module.ignore_inputs_for_loss:
+				student_probs = [ F.log_softmax( student / T, dim=-1 ) for student in output.logits ]
+				teacher_probs = [ F.softmax( teacher / T, dim=-1 ) for teacher in teacher_output.logits ]
+			else:
+				task_outputs = {
+					"tts": "resp",
+					"stt": "text",
+					"len": "len",
+				}
+				output_lens = [ 0 for _ in range(batch_size) ]
+				for batch_index, _batch in enumerate(inputs):
+					task_type = "tts"
+					for name, input in _batch:
+						if name == "task":
+							task_type = input
+
+					for name, input in _batch:
+						if name == task_outputs.get(task_type, name):
+							output_lens[batch_index] = input.shape[0]
+
+				# create probability distributions (literature says to have the students already log'd but not the teacher)
+				student_probs = [ F.log_softmax( student[-l:] / T, dim=-1 ) for student, l in zip( output.logits, output_lens ) ]
+				teacher_probs = [ F.softmax( teacher[-l:] / T, dim=-1 ) for teacher, l in zip( teacher_output.logits, output_lens ) ]
+
+			# filter out logits that are / would inf
+			# this causes problems when computing the loss if there's any inherently never-ever probabilities (for example, NAR RVQ-0 demasking for the stop token, because I did not clip it from the classifier)
+			for batch_index in range( batch_size ):
+				mask_a = student_probs[batch_index] == -float("inf") # log(0) = -inf
+				mask_b = teacher_probs[batch_index] == 0.0 # this gets log'd, eventually creating -inf
+
+				mask = mask_a | mask_b
+				student_probs[batch_index] = torch.masked_select( student_probs[batch_index], ~mask )
+				teacher_probs[batch_index] = torch.masked_select( teacher_probs[batch_index], ~mask )
+
+			if L == "kl":
+				soft_losses = [ F.kl_div( student, teacher, reduction='sum' ) for student, teacher in zip( student_probs, teacher_probs ) ]
+			elif L == "mse":
+				soft_losses = [ F.mse_loss( student, teacher ) for student, teacher in zip( student_probs, teacher_probs ) ]
+			else:
+				soft_losses = [ torch.sum(teacher * (teacher.log() - student)) for student, teacher in zip( student_probs, teacher_probs ) ]
+
+			for k in engine.module.loss.keys():
+				engine.module.loss[k] *= (1.0 - A)
+			engine.module.loss[L] = torch.stack([*soft_losses]).sum() * A * (T ** 2) / batch_size
+
 		losses = engine.gather_attribute("loss")
 		stat = engine.gather_attribute("stats")
 
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 5cb2b10..ef49f61 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -143,7 +143,7 @@ def train(
 	# validate if there's at least one model to train
 	found = False
 	for name, engine in engines.items():
-		if engine.training:
+		if engine._training:
 			found = True
 			break
 	if not found:
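
For reference, a minimal sketch of the distillation mixing that this patch relocates into train_feeder: temperature-softened student log-probabilities against teacher probabilities, a soft loss scaled by T squared, and the existing hard losses down-weighted by (1 - alpha). This is an illustration only, not part of the patch; the function name, toy tensor shapes, and the hard_losses dict are hypothetical, and the patch's per-batch normalization and -inf/zero-probability masking are omitted for brevity.

import torch
import torch.nn.functional as F

def mix_distillation_loss(student_logits, teacher_logits, hard_losses, T=1.0, A=0.5):
	# hypothetical helper, not from the repo: temperature-softened distributions,
	# student as log-probabilities and teacher as plain probabilities, as F.kl_div expects
	student_log_probs = F.log_softmax(student_logits / T, dim=-1)
	teacher_probs = F.softmax(teacher_logits / T, dim=-1)

	# KL(teacher || student), scaled by T**2 so gradient magnitudes stay
	# comparable across temperatures (Hinton et al., 2015)
	soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction="sum") * (T ** 2)

	# down-weight the existing hard-target losses and mix in the soft loss
	mixed = { k: v * (1.0 - A) for k, v in hard_losses.items() }
	mixed["kl"] = soft_loss * A
	return mixed

# toy usage with made-up shapes: 4 output positions, 1025-way classifier
student = torch.randn(4, 1025)
teacher = torch.randn(4, 1025)
print(mix_distillation_loss(student, teacher, { "nll": torch.tensor(2.3) }, T=2.0, A=0.5))

In the patch itself the same mixing is applied in place to engine.module.loss, with the soft-loss key and reduction selected by cfg.hyperparameters.teacher_loss_fn.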