diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
index 1e167c6..9b2c5fb 100644
--- a/vall_e/models/experimental.py
+++ b/vall_e/models/experimental.py
@@ -11,10 +11,12 @@ from ..config import cfg
 from ..data import fold_inputs, unfold_outputs
 
 import torch
+import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from torch import Tensor
 from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import checkpoint
+from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 import random
 import math
@@ -98,9 +100,9 @@ class Model(LlmArchClass):
 		n_heads=16,
 		p_dropout=0.1,
 
-		config = None,
+		config = cfg.model,
 	):
-		self.hyper_config = config
+		self.hyper_config = config
 
 		hf_attention = config.attention if config is not None else None
 		gradient_checkpointing = config.gradient_checkpointing if config is not None else True
@@ -160,6 +162,14 @@ class Model(LlmArchClass):
 
 			self.backbone.gradient_checkpointing = gradient_checkpointing
 
+		self.accuracy_metric = MulticlassAccuracy(
+			vocab_size,
+			top_k=10,
+			average="micro",
+			multidim_average="global",
+			ignore_index=-100,
+		)
+
 	def generate(
 		self,
 		*args,
@@ -188,33 +198,52 @@ class Model(LlmArchClass):
 		if "attention_mask" in kwargs:
 			kwargs.pop("attention_mask")
 
+		labels = kwargs.pop("labels") if "labels" in kwargs else None
+
 		output = super().forward(*args, **kwargs)
+		logits = output.logits
 
-		if SELECTED_ARCH in ["llama", "retnet"]:
-			if output.loss is not None:
-				self.loss = dict(
-					nll = output.loss,
-				)
-		elif SELECTED_ARCH in ["mamba","mamba2"]:
-			if "labels" in kwargs:
-				labels = kwargs.pop("labels")
-				logits = output.logits
-
-				# Shift so that tokens < n predict n
-				shift_logits = logits[..., :-1, :].contiguous()
-				shift_labels = labels[..., 1:].contiguous()
-				# Flatten the tokens
-				loss_fct = CrossEntropyLoss()
-				shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-				shift_labels = shift_labels.view(-1)
-				# Enable model parallelism
-				shift_labels = shift_labels.to(shift_logits.device)
-				loss = loss_fct(shift_logits, shift_labels)
-
+		# i HATE the correct way
+		if labels is not None:
+			if self.hyper_config is None or not self.hyper_config.loss_factors:
+				loss = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits, labels ) ])
 				self.loss = dict(
 					nll = loss,
 				)
+				self.stats = dict(
+					acc = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits, labels ) ] ) / len( logits )).item()
+				)
+
+			else:
+				sep = 3
+				# determine specific sections to focus on
+				indices = [ [ idx for idx, token in enumerate( batch ) if token == sep ] for i, batch in enumerate( labels ) ]
+
+				text_index = 0
+				resp_index = 1 # 1 includes everything non-text, -3 includes pre_resp + resp (ignores prom, probably better to include prom here)
+
+				labels_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( labels ) ]
+				labels_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( labels ) ]
+
+				logits_text = [ batch[:indices[i][text_index] + 1 ] for i, batch in enumerate( logits ) ]
+				logits_resp = [ batch[indices[i][resp_index] + 1:] for i, batch in enumerate( logits ) ]
+
+				loss_text = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_text, labels_text ) ]) * self.hyper_config.loss_factor("text")
+				loss_resp = sum([ F.cross_entropy( logit[:-1, :], label[1:], ignore_index=-100 ) for logit, label in zip( logits_resp, labels_resp ) ]) * self.hyper_config.loss_factor("resp")
+
+				self.loss = dict(
+					text = loss_text,
+					resp = loss_resp,
+				)
+
+				self.stats = dict(
+					acc = dict(
+						text = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_text, labels_text ) ] ) / len( logits_text )).item(),
+						resp = (sum([ self.accuracy_metric( logit, target ) for logit, target in zip( logits_resp, labels_resp ) ] ) / len( logits_resp )).item(),
+					)
+				)
+
 		return output
 
 def example_usage():
@@ -412,6 +441,7 @@ def example_usage():
 		target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels)
 
 		stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+		stats |= engine.gather_attribute("stats")
 		stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 		tqdm.write(f"{stats}")
diff --git a/vall_e/train.py b/vall_e/train.py
index 64a5a01..7909348 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -30,10 +30,10 @@ def train_feeder(engine, batch):
 	batch_size = len(batch["text"])
 	if cfg.model.interleave:
 		quant_levels = None
-		resps_list = [ resp for resp in resp_list ]
+		resps_list = [ resp for resp in batch["resps"] ]
 	else:
 		quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
-		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
+		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
 
 	input_ids, attention_mask = fold_inputs(
 		text_list=batch["text"],