From 23d402bf011107bcb9933baace169987d4381193 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 5 Dec 2024 23:05:52 -0600
Subject: [PATCH] added knowledge distillation in the trainer (sadly it is not
 agnostic because of the grave mistake of further processing the batch within
 the forward pass, so subsequent calls do not match......)

---
 vall_e/config.py            |  14 ++++-
 vall_e/engines/__init__.py  |   8 ++-
 vall_e/engines/base.py      |  26 ++++++--
 vall_e/engines/deepspeed.py |   4 ++
 vall_e/models/ar_nar.py     |  14 ++++-
 vall_e/models/base.py       | 119 +++++++++++++++++++++++++-----------
 vall_e/train.py             |   3 +-
 7 files changed, 142 insertions(+), 46 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index e491b42..836d29b 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -115,7 +115,8 @@ class BaseConfig:
 			raise Exception(f'Model path does not exist: {model_path}')
 
 		# load state dict and copy its stored model config
-		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path, "attention": "auto" } ] if model_path and model_path.exists() else []
+		model_kwargs = { "attention": "auto", "training": False, "teacher": False }
+		model_state_dict = [ torch_load( model_path )["config"] | { "path": model_path } | model_kwargs ] if model_path and model_path.exists() else []
 		lora_state_dict = [ torch_load( lora_path )["config"] | { "path": lora_path } ] if lora_path and lora_path.exists() else []
 
 		state = { "models": model_state_dict, "loras": lora_state_dict, "trainer": { "load_state_dict": True } }
@@ -279,6 +280,8 @@ class ModelExperimentalSettings:
 	layerskip_p_max: float = 0.1 # maximum probabilty to dropout the last layer, used for calculating layer dropout probabilities
 	layerskip_e_scale: float = 0.2 # early-exit loss scalar value
 
+	teacher_alpha: float = 0.5 # mixing factor when performing knowledge distillation
+
 # I really need to clean this up
 @dataclass()
 class Model:
@@ -291,7 +294,9 @@ class Model:
 	tones: int = 1 # defined tones (unsued)
 	experts: int = 1 # for mixtral / retnet-ts
 	arch_type: str = "llama" # underling LM architecture used
-	training: bool = True # I really need to attend to this
+	training: bool = False # I really need to attend to this
+	teacher: bool = False # if this is to be treated as a teacher
+
 	frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
 	attention: str = "auto" # for llama arch_types: attention used
 	dropout: float = 0.1 # adjustable dropout value
@@ -1006,6 +1011,11 @@ class Config(BaseConfig):
 			if isinstance( model.experimental, dict ):
 				model.experimental = ModelExperimentalSettings(**model.experimental)
 
+			if model.teacher:
+				model.training = False
+			if model.training:
+				model.teacher = False
+
 		if self.hyperparameters.scheduler_type and not self.hyperparameters.scheduler:
 			self.hyperparameters.scheduler = self.hyperparameters.scheduler_type
 			self.hyperparameters.scheduler_type = ""
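
The config change above boils down to one invariant: `teacher` and `training` are mutually exclusive, with `teacher` taking priority, and checkpoints are reloaded with both flags off. A minimal sketch of that resolution, using a stand-in dataclass rather than the project's `Model` config:

    from dataclasses import dataclass

    @dataclass
    class ModelRole:  # stand-in for the relevant fields on the Model config
        training: bool = False
        teacher: bool = False

    def resolve_role(m: ModelRole) -> ModelRole:
        if m.teacher:   # a teacher is never trained
            m.training = False
        if m.training:  # and a model being trained is never the teacher
            m.teacher = False
        return m

`teacher_alpha` lives in `ModelExperimentalSettings` and only comes into play on the student side, when the soft loss is mixed in.
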
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index f6cbd99..3685035 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -46,7 +46,7 @@ def load_engines(training=True, **model_kwargs):
 		stats = None
 		lora = None
 
-		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+		inferencing = cfg.mode == "inferencing" or not model.config.training or not training or model.config.teacher
 		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		loads_state_dict = cfg.trainer.load_state_dict # or inferencing
 
@@ -327,6 +327,12 @@ def load_engines(training=True, **model_kwargs):
 		if cfg.optimizations.model_offloading:
 			engine.module = ml.offload_model( engine.module, policy=cfg.optimizations.model_offloading )
 
+		# set to train/eval
+		if engine.hyper_config.training:
+			engine.module.train()
+		else:
+			engine.module.eval()
+
 		# setup wandb
 		if engine._training and cfg.trainer.wandb and wandb is not None:
 			key_name = name
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index f102d31..270b477 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -104,6 +104,12 @@ class Engine():
 			return True
 		return self.hyper_config.training
 
+	@property
+	def _teacher(self):
+		if not hasattr(self, "hyper_config"):
+			return False
+		return self.hyper_config.teacher
+
 	@property
 	def global_step(self):
 		return self.global_steps
@@ -352,6 +358,9 @@ class Engines(dict[str, Engine]):
 				lora, module = lora_get_state_dict( module, split = True )
 				save_path = cfg.ckpt_dir / cfg.lora.full_name / f"{cfg.weights_name}.{format}"
 
+			config_dict = dict(**config.__dict__)
+			config_dict |= {"experimental": config.experimental.__dict__}
+
 			state_dict = {
 				'module': module,
 				'lora': lora,
@@ -362,7 +371,7 @@ class Engines(dict[str, Engine]):
 					"tokens_processed": engine.tokens_processed,
 				},
 				"userdata": userdata,
-				"config": config.__dict__ | {"experimental": config.experimental.__dict__} # i hate implicit aliasing rules
+				"config": config_dict
 			}
 
 			if lora is None:
@@ -478,8 +487,17 @@ class Engines(dict[str, Engine]):
 		if cfg.trainer.gc_mode == 'step':
 			do_gc()
 
+		# preiterate to get teacher
+		teacher = None
 		for name, engine in self.items():
-			if not engine._training:
+			if not engine._teacher:
+				continue
+			teacher = engine.module
+			break
+
+		for name, engine in self.items():
+			# only models that we're training
+			if not engine._training or engine._teacher:
 				continue
 
 			device = engine.device
@@ -493,10 +511,10 @@ class Engines(dict[str, Engine]):
 			n_ooms = torch.zeros([], device=device)
 
 			if not cfg.trainer.check_for_oom:
-				res = feeder( engine=engine, batch=batch )
+				res = feeder( engine=engine, batch=batch, teacher=teacher )
 			else:
 				try:
-					res = feeder( engine=engine, batch=batch )
+					res = feeder( engine=engine, batch=batch, teacher=teacher )
 				except RuntimeError as e:
 					_logger.error(f"Forward: {str(e)}")
 
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 04809f2..4edc001 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -92,6 +92,10 @@ class Engine(DeepSpeedEngine):
 	def _training(self):
 		return self.hyper_config.training
 
+	@property
+	def _teacher(self):
+		return self.hyper_config.teacher
+
 	@property
 	def global_step(self):
 		return self.global_steps
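
Taken together, the engine changes make a teacher-flagged model load in inference mode (frozen and put in `eval()`), skip the training loop, and get handed to every student's feeder once per step. A condensed sketch of that step shape, with `engines` and `feeder` as stand-ins for the real `Engines`/`train_feeder` objects:

    def training_step(engines: dict, batch, feeder):
        # find the (single) module flagged as teacher, if any
        teacher = next(
            (engine.module for engine in engines.values() if engine._teacher),
            None,
        )
        # every trainable, non-teacher engine sees the same batch plus the teacher
        for engine in engines.values():
            if not engine._training or engine._teacher:
                continue
            feeder(engine=engine, batch=batch, teacher=teacher)

If no engine is flagged, `teacher` stays `None` and the step is identical to the old behaviour.
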
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 83111dc..21945b2 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -32,6 +32,10 @@ from ..samplers import cfg_logits
 text_task = [ "stt" ]
 
 class AR_NAR(Base):
+	# yikes
+	def forward_super(self, *args, **kwargs):
+		return super().forward(*args, **kwargs)
+
 	# parse inputs for training
 	# a lot of this could be delegated back to the dataloader, but it's just easier to keep the task of the dataloader to provide sufficient data, and the model to process the data for training
 	def forward_train(
@@ -44,6 +48,8 @@ class AR_NAR(Base):
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
+
+		teacher = None,
 	):
 		# deduce batch_size
 		if text_list is not None:
@@ -198,6 +204,7 @@ class AR_NAR(Base):
 		return super().forward(
 			inputs=inputs,
 			quant_levels=quant_levels,
+			teacher=teacher,
 		)
 
 	def forward_nar_masked(
@@ -834,7 +841,8 @@ class AR_NAR(Base):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 
-		training: bool | int | None = None,
+		training: bool | None = None,
+		teacher = None,
 
 		disable_tqdm=False,
 		use_lora=None,
@@ -871,8 +879,8 @@ class AR_NAR(Base):
 				lang_list=lang_list,
 				tone_list=tone_list,
 				len_list=len_list,
-				disable_tqdm=disable_tqdm,
-				use_lora=use_lora,
+
+				teacher=teacher,
 			)
 
 		# is NAR
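
`forward_super` exists because of the problem the commit message complains about: `forward_train` builds (and in places randomizes) the batch inside the forward pass, so calling the teacher's own training forward on the raw batch would produce logits that do not line up with the student's. Handing the already-built `inputs`/`quant_levels` straight to the teacher's `Base.forward` keeps the two aligned. A toy illustration of the mismatch, not the project's classes:

    import random

    class Toy:
        """Stand-in for a model that further processes the batch inside its training forward."""
        def process(self, batch):
            # e.g. a quantizer level is drawn per sample at forward time
            return [(x, random.randint(0, 7)) for x in batch]

        def forward(self, inputs):
            return [f"logits({x}@level{lvl})" for x, lvl in inputs]

        def forward_train(self, batch):
            return self.forward(self.process(batch))

    student, teacher = Toy(), Toy()
    inputs = student.process(["a", "b"])
    aligned = teacher.forward(inputs)               # same processed inputs as the student
    misaligned = teacher.forward_train(["a", "b"])  # re-processed independently; levels likely differ
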
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b08e25c..ca67fee 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -38,7 +38,7 @@ from ..emb.qnt import encode_as_embedding
 from ..data import get_task_symmap
 
 # these seem more elegant than a dict
-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
+Logits = namedtuple('Logits', ['logits', 'state', 'loss', 'attentions', 'hidden_states', 'exited_layer'])
 Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
@@ -389,12 +389,13 @@ class Base(nn.Module):
 
 		l_padding: int = 0,
 
-		training = True, 
+		training = True,
 		attention = None,
 		config = None, 
 	):
 		super().__init__()
 		self.training = training
+		self.teaching = False
 		self.config = config
 
 		self.n_text_tokens = n_text_tokens
@@ -428,6 +429,11 @@ class Base(nn.Module):
 		if not attention:
 			attention = self.config.attention if self.config is not None else "auto"
 
+		# crunge
+		if self.config is not None and config.teacher:
+			self.teaching = True
+			self.training = False
+
 		attention_backend = attention
 		audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
 		split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
@@ -436,6 +442,7 @@ class Base(nn.Module):
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
 		noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
+		teacher_alpha = self.config.experimental.teacher_alpha if self.config is not None else 0.5
 
 		masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
 		ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
@@ -485,6 +492,7 @@ class Base(nn.Module):
 		self.masking_ratio = masking_ratio
 		self.ignore_inputs_for_loss = ignore_inputs_for_loss
 		self.noncausal_masks = noncausal_masks
+		self.teacher_alpha = teacher_alpha
 
 		# use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends
 		"""
@@ -1265,6 +1273,8 @@ class Base(nn.Module):
 		logits,
 
 		quant_levels: list[int] | None = None,
+		compute_hard_loss = True,
+		compute_acc = True,
 	):
 		loss = {}
 		stats = {}
@@ -1297,6 +1307,7 @@ class Base(nn.Module):
 			task_type = "tts"
 			dropout_mask = None
 			classifier_level = None
+			output_len = 0
 
 			for name, input in batch:
 				if name == "task":
@@ -1354,8 +1365,11 @@ class Base(nn.Module):
 				it += seq_len + 1 # +1 to incorporate the separator
 
 				# deduce if a name for a task is an input or output
-				if self.ignore_inputs_for_loss and name != task_outputs.get(task_type, name):
-					ignored = True
+				if name != task_outputs.get(task_type, name):
+					if self.ignore_inputs_for_loss:
+						ignored = True
+					else:
+						output_len = seq_len
 
 				if ignored:
 					# pruned
@@ -1378,20 +1392,20 @@ class Base(nn.Module):
 						logit = logit[..., :-l, :]
 						token = token[..., l:] # shift sequence to the right by one (or causal chunk size)
 
-					if f'{name}.nll' not in loss:
-						loss[f'{name}.nll'] = []
+					if compute_hard_loss:
+						nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
+						if f'{name}.nll' not in loss:
+							loss[f'{name}.nll'] = []
+						loss[f'{name}.nll'].append( nll )
 
-					if f'{name}.acc' not in stats:
-						stats[f'{name}.acc'] = []
-
-					nll = F.cross_entropy( logit, token.long(), ignore_index=self.ignore_index ) * loss_factor
-					if self.metrics is not None:
-						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
-					else:
-						metrics = self.accuracy_metric( logit, token )
-
-					loss[f'{name}.nll'].append( nll )
-					stats[f'{name}.acc'].append( metrics )
+					if compute_acc:
+						if self.metrics is not None:
+							metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
+						else:
+							metrics = self.accuracy_metric( logit, token )
+						if f'{name}.acc' not in stats:
+							stats[f'{name}.acc'] = []
+						stats[f'{name}.acc'].append( metrics )
 				# add to list
 				else:
 					target.append( token )
@@ -1407,21 +1421,21 @@ class Base(nn.Module):
 				logit = logit[..., :-l, :] # shift the target so that token n...
 				target = target[..., l:] # ...predicts token n + 1
 
-			nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
+			if compute_hard_loss:
+				nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
+				if 'nll' not in loss:
+					loss['nll'] = []
+				loss["nll"].append( nll )
 
-			if self.metrics is not None:
-				metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
-			else:
-				metrics = self.accuracy_metric( logit, target )
+			if compute_acc:
+				if self.metrics is not None:
+					metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
+				else:
+					metrics = self.accuracy_metric( logit, target )
 
-			if 'nll' not in loss:
-				loss['nll'] = []
-
-			if 'acc' not in stats:
-				stats['acc'] = []
-
-			loss["nll"].append( nll )
-			stats["acc"].append( metrics )
+				if 'acc' not in stats:
+					stats['acc'] = []
+				stats["acc"].append( metrics )
 
 		# average
 		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
@@ -1440,6 +1454,8 @@ class Base(nn.Module):
 
 		output_attentions: bool = False,
 		output_hidden_states: bool = False,
+
+		teacher = None,
 	):
 		# return early if it's "good" enough"
 		# lambda because we need to capture the classifier_levels and mask
@@ -1492,6 +1508,7 @@ class Base(nn.Module):
 		x, mask = list_to_tensor(x_list)
 
 		training = self.training
+		teaching = self.teaching
 		device = x.device
 		batch_size = len(x_list)
 
@@ -1566,8 +1583,14 @@ class Base(nn.Module):
 				hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
 
 		# compute loss if the target is given
-		if training:
-			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+		if not training:
+			loss = None
+			stats = None
+
+			self.loss = None
+			self.stats = None
+		else:
+			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels, compute_hard_loss=training, compute_acc=training )
 
 			# compute it as an aux-loss
 			if self.layerskip:
@@ -1590,15 +1613,41 @@ class Base(nn.Module):
 			# to-do: instead make the cirriculum rely on samples processed instead of steps
 			self.training_steps += 1 # batch_size
 
+			# get soft targets from teacher
+			# it might be better to compute these once instead of per-engine, but realistically who is actually training multiple models
+			if teacher is not None:
+				with torch.no_grad():
+					teacher_output = teacher.forward_super(
+						inputs=inputs,
+						quant_levels=quant_levels,
+					)
+
+				soft_loss = [
+					F.kl_div(
+						F.log_softmax( student, dim=-1 ).unsqueeze(0),
+						F.softmax( teacher, dim=-1 ).unsqueeze(0),
+						reduction='batchmean'
+					)
+					for student, teacher in zip( logits, teacher_output.logits )
+				]
+				soft_loss = torch.stack([*soft_loss]).sum() / batch_size
+
+				# mix if not nan
+				if not torch.isnan(soft_loss).any():
+					alpha = self.teacher_alpha
+					loss['kl'] = alpha * soft_loss
+					for k in loss.keys():
+						loss[k] *= (1.0 - alpha)
+
 			# include any additional losses (for example: MoE router)
-			if output.aux_loss is not None:
-				loss["aux_loss"] = output.aux_loss
+			if output.loss is not None:
+				loss["aux_loss"] = output.loss
 
 			self.loss = loss
 			self.stats = stats
 
 		# rewrap, because we're modifying the logits here
-		return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states, exited_layer)
+		return Logits(logits, output.state, loss, output.attentions, hidden_states, exited_layer)
 
 	def sample(
 		self,
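
The heart of the patch is the soft-target loss: for each sequence, the student's log-probabilities are pulled toward the teacher's distribution with a KL-divergence term, and that term is mixed with the ordinary cross-entropy via `teacher_alpha`. A self-contained sketch of that mixing (names and shapes are illustrative, temperature is left at the implicit T=1 the patch uses, and the NaN guard is omitted):

    import torch
    import torch.nn.functional as F

    def mix_distillation_loss(hard_loss, student_logits, teacher_logits, alpha=0.5):
        # KL(teacher || student) over the vocabulary, averaged over positions
        soft_loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean",
        )
        # scaling the hard loss first and adding the KL term afterwards keeps the
        # KL contribution weighted by exactly alpha
        return (1.0 - alpha) * hard_loss + alpha * soft_loss

    # toy usage: 7 positions, 32-token vocabulary
    student = torch.randn(7, 32, requires_grad=True)
    teacher = torch.randn(7, 32)
    targets = torch.randint(0, 32, (7,))
    hard = F.cross_entropy(student, targets)
    total = mix_distillation_loss(hard, student, teacher, alpha=0.5)
    total.backward()

Note that in the patch itself every entry of the `loss` dict is scaled by `(1.0 - alpha)` after `loss['kl']` has already been inserted, so the KL term effectively carries a weight of `alpha * (1 - alpha)`; if that is not intended, scaling the existing entries before inserting `loss['kl']` gives the usual `(1 - alpha) * hard + alpha * soft` split.
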
diff --git a/vall_e/train.py b/vall_e/train.py
index 0bccc00..1c13ff8 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -26,7 +26,7 @@ _logger = logging.getLogger(__name__)
 
 mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
 
-def train_feeder(engine, batch):
+def train_feeder(engine, batch, teacher=None):
 	with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
 		batch_size = len(batch["text"])
 		engine.current_batch_size = batch_size
@@ -40,6 +40,7 @@ def train_feeder(engine, batch):
 			task_list=batch["task"],
 
 			training=True,
+			teacher=teacher,
 		)
 
 		losses = engine.gather_attribute("loss")
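
End to end, enabling distillation should then only require flagging one model as a teacher in the config (mirroring the `teacher` field on the `Model` dataclass) while the student keeps `training` set: the teacher is loaded frozen in `eval()` mode, skipped by the training loop, and handed to the feeder, while a `teacher` of `None` leaves the loss untouched since the whole KL block is guarded by `if teacher is not None`. A stand-in for that call, using only the batch keys visible in the diff:

    def feed(engine, batch, teacher=None):
        # `teacher` just rides along; None means "no distillation", and the
        # model's loss comes out exactly as before this patch
        return engine(
            text_list=batch["text"],
            task_list=batch["task"],
            training=True,
            teacher=teacher,
        )
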