From 319ca09a4f31e5ebb68d034e998192d61c85f245 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 12 Feb 2025 23:36:32 -0600
Subject: [PATCH] cleanup

---
 vall_e/models/ar_nar.py |   2 +-
 vall_e/models/base.py   | 324 ++++++++++++----------------------------
 2 files changed, 95 insertions(+), 231 deletions(-)

diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index afd1e1c..8dd6a04 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -1254,7 +1254,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 750 // batch_size
+	steps = 500 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c401f08..21c9adb 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -322,6 +322,21 @@ class Classifiers(nn.Module):
 		]
 		return torch.stack( xi )
 
+def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
+	"""
+	x = x.clone().detach().t()
+	for l, t in enumerate( x ):
+		x[l] = torch.where( dropout_mask, dropout_token, x[l] )
+	return x.t()
+	"""
+	x = x.clone().detach()
+	levels = x.shape[-1]
+	for level in range( levels ):
+		lhs = dropout_token if not swapped else x[..., level]
+		rhs = x[..., level] if not swapped else dropout_token
+		x[..., level] = torch.where( dropout_mask, lhs, rhs )
+	return x
+
 # naively embeds each level of a codebook, then merges the embeddings with a Linear
 class AudioEncoder(nn.Module):
 	def __init__(
@@ -336,10 +351,7 @@ class AudioEncoder(nn.Module):
 
 	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
 		if dropout_mask is not None:
-			xi = xi.clone().detach().t()
-			for l, t in enumerate( xi ):
-				xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
-			xi = xi.t()
+			xi = _dropout_codes( xi, dropout_mask, dropout_token )
 
 		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
 		x = self.proj(x)
@@ -390,8 +402,10 @@ class AudioDecoder(nn.Module):
 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
 		x = self.up( x )
 
+		"""
 		if self.transformer is not None:
 			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
+		"""
 
 		x = self.down( x )
 		batch_size, seq_len, dim = x.shape
@@ -1490,169 +1504,6 @@ class Base(nn.Module):
 
 		return ids.to(device=device, dtype=torch.int32)
 
-	def calc_loss_new(
-		self,
-		inputs: list,
-		logits,
-
-		compute_hard_loss = True,
-		compute_acc = True,
-	):
-		loss = {}
-		stats = {}
-
-		device = logits[0].device
-		batch_size = len(logits)
-		classifier_levels = self.get_input( inputs, "classifier_level" )
-
-		# handles tasks where the prompt has task tokens injected in the middle
-		def prompt_input_to_token( input ):
-			if isinstance(input, str):
-				return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
-			return input
-
-		for batch_index, batch in enumerate(inputs):
-			target = []
-			causal = False
-			task_type = "tts"
-			dropout_mask = None
-			classifier_level = None
-			output_len = 0
-
-			for name, input in batch:
-				if name == "task":
-					task_type = input
-				elif name == "dropout_mask":
-					dropout_mask = input
-				elif name == "classifier_level":
-					classifier_level = input
-
-			it = 0
-			for name, input in batch:
-				token = None
-				ignored = False
-
-				# non-tokened tasks
-				if name in non_tokened_names:
-					continue
-				# prom can either be a tensor itself or a list of tensors and strings
-				if name == "prom":
-					# expand to list if not a list
-					proms = [ input ] if isinstance(input, torch.Tensor) else input
-					# iterate over the list to inject their tokens
-					token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] )
-				elif name == "resp":
-					# mask found, apply it
-					if dropout_mask is not None:
-						token = torch.where( dropout_mask, input.t(), self.ignore_index ).t()
-					else:
-						token = input
-				# not a special input, inject as-is
-				else:
-					token = input
-
-				if not isinstance(token, torch.Tensor):
-					continue
-
-				if token.is_floating_point():
-					ignored = True
-
-				# grab range of our logits for later
-				seq_len = token.shape[0]
-				start, end = it, it+seq_len
-				it += seq_len + 1 # +1 to incorporate the separator
-
-				# deduce if a name for a task is an input or output
-				if name != task_outputs.get(task_type, name):
-					if self.ignore_inputs_for_loss:
-						ignored = True
-					# cringe
-					if task_type != "tts":
-						ignored = True
-				else:
-					output_len = seq_len
-
-				if ignored:
-					# pruned
-					if self.config.loss_factors:
-						continue
-					# fill with ignored out tensor
-					token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)
-
-				# perform loss calculation on the individual piece
-				target.append( token )
-
-			if logits[batch_index].dim() != 3:
-				seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-				logit = logits[batch_index]
-
-				# shift if causal
-				if causal:
-					l = self.causal_size
-					logit = logit[..., :-l, :] # shift the target so that token n...
-					seq = seq[..., l:] # ...predicts token n + 1
-
-				if compute_hard_loss:
-					nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
-					if 'nll' not in loss:
-						loss['nll'] = []
-					loss["nll"].append( nll )
-
-				if compute_acc and False:
-					if self.metrics is not None:
-						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
-					else:
-						accuracy_metric = MulticlassAccuracy(
-							logit.shape[-1],
-							top_k = 10,
-							average="micro",
-							multidim_average="global",
-							ignore_index = -100
-						).to(logit.device)
-						metrics = accuracy_metric( logit, seq )
-
-					if 'acc' not in stats:
-						stats['acc'] = []
-					stats["acc"].append( metrics )
-			else:
-				for level, logit in enumerate( logits[batch_index] ):
-					seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) )
-
-					# shift if causal
-					if causal:
-						l = self.causal_size
-						logit = logit[..., :-l, :] # shift the target so that token n...
-						seq = seq[..., l:] # ...predicts token n + 1
-
-					if compute_hard_loss:
-						nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index )
-						if 'nll' not in loss:
-							loss['nll'] = []
-						loss["nll"].append( nll )
-
-					if compute_acc and False:
-						if self.metrics is not None:
-							metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
-						else:
-							accuracy_metric = MulticlassAccuracy(
-								logit.shape[-1],
-								top_k = 10,
-								average="micro",
-								multidim_average="global",
-								ignore_index = -100
-							).to(logit.device)
-							metrics = accuracy_metric( logit, seq )
-
-						if 'acc' not in stats:
-							stats['acc'] = []
-						stats["acc"].append( metrics )
-
-		# average
-		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
-		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
-
-		return LossStats(loss, stats)
-
 	def calc_loss(
 		self,
 		inputs: list,
@@ -1723,12 +1574,18 @@ class Base(nn.Module):
 					token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
 				elif name == "resp":
 					# mask found, apply it
-					if dropout_mask is not None:
-						# if mask use original token, else ignore
-						token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index )
-					# use resps as-is
-					else:
+					if self.version < 7:
 						token = input if input.dim() == 1 else input[:, quant_level]
+
+						# mask found, apply it
+						if dropout_mask is not None:
+							token = torch.where( dropout_mask, token, self.ignore_index )
+					else:
+						token = input
+
+						# mask found, apply it
+						if dropout_mask is not None:
+							token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
 				# not a special input, inject as-is
 				else:
 					token = input
@@ -1763,6 +1620,9 @@
 
 			# perform loss calculation on the individual piece
 			if self.config.loss_factors:
+				# to-do: make this work with version >= 7
+				assert self.version < 7, "Unsupported"
+
 				loss_factor = self.loss_factor(name)
 
 				if loss_factor == 0.0:
@@ -1782,7 +1642,7 @@
 					loss[f'{name}.nll'].append( nll )
 
 				if compute_acc:
-					if self.metrics is not None:
+					if self.metrics is not None and classifier_level in self.classifiers.names:
 						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
 					else:
 						accuracy_metric = MulticlassAccuracy(
@@ -1803,37 +1663,43 @@
 
 			# perofrm loss calculation on the entire sequence
 			if not self.config.loss_factors:
-				target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
-				logit = logits[batch_index]
+				def _calc_loss( logit, input ):
+					sequence = _join( input, torch.tensor(self.ignore_index, device=input[-1].device) )
 
-				# shift if causal
-				if causal:
-					l = self.causal_size
-					logit = logit[..., :-l, :] # shift the target so that token n...
-					target = target[..., l:] # ...predicts token n + 1
+					# shift if causal
+					if causal:
+						l = self.causal_size
+						logit = logit[..., :-l, :] # shift the target so that token n...
+						sequence = sequence[..., l:] # ...predicts token n + 1
 
-				if compute_hard_loss:
-					nll = F.cross_entropy( logit, target, ignore_index=self.ignore_index )
-					if 'nll' not in loss:
-						loss['nll'] = []
-					loss["nll"].append( nll )
+					if compute_hard_loss:
+						nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
+						if 'nll' not in loss:
+							loss['nll'] = []
+						loss["nll"].append( nll )
 
-				if compute_acc:
-					if self.metrics is not None:
-						metrics = self.metrics.calc_accuracy( [ logit ], [ target ], self.classifiers.indices([ classifier_level ]) )
-					else:
-						accuracy_metric = MulticlassAccuracy(
-							logit.shape[-1],
-							top_k = 10,
-							average="micro",
-							multidim_average="global",
-							ignore_index = -100
-						).to(logit.device)
-						metrics = accuracy_metric( logit, target )
+					if compute_acc:
+						if self.metrics is not None and classifier_level in self.classifiers.names:
+							metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
+						else:
+							accuracy_metric = MulticlassAccuracy(
+								logit.shape[-1],
+								top_k = 10,
+								average="micro",
+								multidim_average="global",
+								ignore_index = -100
+							).to(logit.device)
+							metrics = accuracy_metric( logit, sequence )
 
-					if 'acc' not in stats:
-						stats['acc'] = []
-					stats["acc"].append( metrics )
+						if 'acc' not in stats:
+							stats['acc'] = []
+						stats["acc"].append( metrics )
+
+				if logits[batch_index].dim() < 3:
+					_calc_loss( logits[batch_index], target )
+				else:
+					for level, logit in enumerate( logits[batch_index] ):
+						_calc_loss( logit, [ x if x.dim() <= 1 else x[:, level] for x in target ] )
 		# average
 		loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
 		stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
@@ -1904,30 +1770,36 @@
 		logits = output.logits
 		hidden_states = output.hidden_states
 
-		logits = [ logit for logit in logits ]
-
+		# split between the two logit tasks, as audio logits become expanded
 		if self.version >= 7:
-			p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
-			if p_indices:
-				p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+			logits = [ logit for logit in logits ]
 
-				p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
-				p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
-				p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
+			audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
+
+			decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
+			classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
 
-				p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
-
-				for i, logit in enumerate(p_logits):
-					logits[p_indices[i]] = logit
-
-		# output projection layer
-		# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
-		if self.classifier is not None:
-			logits = self.classifier(logits) # * m
-		# to-do: piece-wise classification, now that there's a head for text
-		# although again, one single monolithic head would be preferable instead......
-		elif self.classifiers is not None:
-			logits = self.classifiers(logits, levels = classifier_levels )
+			if decoders_indices:
+				decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
+				decoders_logits = self.audio_decoder( decoders_logits )
+				for batch_index, logit in zip( decoders_indices, decoders_logits ):
+					logits[batch_index] = logit
+
+			if classifiers_indices:
+				classifiers_levels = [ classifier_levels[batch_index] for batch_index in classifiers_indices ]
+				classifiers_logits = torch.stack([ logits[batch_index] for batch_index in classifiers_indices ])
+				classifiers_logits = self.classifiers( classifiers_logits, levels = classifiers_levels )
+				for batch_index, logit in zip( classifiers_indices, classifiers_logits ):
+					logits[batch_index] = logit
+		else:
+			# output projection layer
+			# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
+			if self.classifier is not None:
+				logits = self.classifier(logits) # * m
+			# to-do: piece-wise classification, now that there's a head for text
+			# although again, one single monolithic head would be preferable instead......
+			elif self.classifiers is not None:
+				logits = self.classifiers(logits, levels = classifier_levels )
 
 		# Remove padding
 		logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@@ -1938,16 +1810,8 @@
 
 			self.loss = None
 			self.stats = None
+
 		# compute loss if the target is given
-		elif self.version >= 7:
-			loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )
-
-			# include any additional losses (for example: MoE router)
-			if output.loss is not None:
-				loss["aux_loss"] = output.loss
-
-			self.loss = loss
-			self.stats = stats
 		else:
 			loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )