From dbd34b64300cfd60ca71f9b3dbd3c23bfdbe6711 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 7 Mar 2025 18:44:11 -0600
Subject: [PATCH] add specialized calc_loss because schizo

---
 vall_e/config.py         |   1 +
 vall_e/data.py           |   2 +-
 vall_e/models/base_v2.py | 243 +++++++++++++++++++++++++++++++++------
 3 files changed, 213 insertions(+), 33 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 05dec9e..34379b4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -289,6 +289,7 @@ class ModelExperimentalSettings:
     # list of floats to manually set
     use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend after itself
     # this is a flag since I am cautious
+    use_streamlined_calc_loss: bool = False # explicitly request the faster loss-calculation pathway, in case computing losses one by one instead of in a single batch is a bottleneck

     # these technically should be as hyperparameters
     # performs token dropout to compensate for errors
diff --git a/vall_e/data.py b/vall_e/data.py
index 81ae375..f30c70b 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -1672,7 +1672,7 @@ def _create_dataloader(dataset, training):
         num_workers=cfg.dataset.workers,
         collate_fn=collate_fn,
         persistent_workers=cfg.dataset.workers > 1,
-        pin_memory=False,
+        pin_memory=True,
         worker_init_fn=_seed_worker,
         **kwargs,
     )
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 6b34689..1181590 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -82,6 +82,8 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
     return x

 # aims to properly encode RVQ-encoded token sequence into an embedding
+# this and the decoder might not work, as I haven't gotten speech to emerge with them (although they might just need more training time)
+# while the FSQ version works, it might be possible to just use it instead and hope the learnable level weights make up for the FSQ-ness
 class ResidualAudioEncoder(nn.Module):
     def __init__(
         self,
@@ -147,6 +149,7 @@ class ResidualAudioDecoder(nn.Module):
         return torch.stack([ self._forward(x) for x in x_i ], dim=0)

 # the above, but for FSQ codecs, as each level is independent from one another
+# this for sure "works", as speech emerges to some extent
 class FiniteAudioEncoder(nn.Module):
     def __init__(
         self,
@@ -332,6 +335,7 @@ class Base_V2(nn.Module):
         logit_normalization = config.experimental.logit_normalization if config is not None else 0
         per_level_normalization = config.experimental.per_level_normalization if config is not None else True
         use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
+        use_streamlined_calc_loss = config.experimental.use_streamlined_calc_loss if config is not None else True

         n_vocab = 256
         n_tasks = config.tasks if config is not None else 8
@@ -421,6 +425,7 @@ class Base_V2(nn.Module):
         self.audio_level_loss_factors = audio_level_loss_factors
         self.logit_normalization = logit_normalization
         self.use_segmented_attention_mask = use_segmented_attention_mask
+        self.use_streamlined_calc_loss = use_streamlined_calc_loss

         self.sep = nn.Parameter(torch.randn(d_model))

@@ -907,7 +912,6 @@ class Base_V2(nn.Module):
         device = logits[0].device
         batch_size = len(logits)
         classifier_levels = self.get_input( inputs, "classifier_level" )
-        level_loss_factor = self.audio_level_loss_factors

         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
             if isinstance(input, str):
                 return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)

@@ -918,6 +922,8 @@ class Base_V2(nn.Module):
         k_lo, k_hi = 1, 20
         def _calc_loss( logit, sequence, causal = True, level = None ):
+            level_loss_factors = self.audio_level_loss_factors
+
             # filter tokens that exceed the vocab size
             sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
             # drop if all tokens are ignored
@@ -951,11 +957,14 @@ class Base_V2(nn.Module):

             if compute_hard_loss:
                 reduction = 'mean' if not batched else 'none'
-                weight = level_loss_factor[level] if level is not None and not batched else 1
-                nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
+                weight = level_loss_factors[level] if level is not None and not batched else 1
+                loss_func = F.cross_entropy # to-do: add mse_loss
+                loss_kwargs = dict(ignore_index=self.ignore_index) if loss_func == F.cross_entropy else {}
+
+                nll = loss_func( logit, sequence, reduction=reduction, **loss_kwargs ) * weight
                 # manually weigh each level
                 if batched:
-                    nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factor, device=device)
+                    nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_loss_factors, device=device)

             if compute_acc:
                 if logit.shape[0] >= k_lo:
@@ -1168,6 +1177,172 @@ class Base_V2(nn.Module):

         return LossStats(loss, stats)

+    # a specialized loss calculation that makes a lot of assumptions, to try and streamline things into one loss calculation instead of many
+    def calc_loss_specialized(
+        self,
+        inputs: list,
+        logits,
+
+        quant_levels: list[int] | None = None,
+        compute_hard_loss = True,
+        compute_acc = True,
+    ):
+        loss = {}
+        stats = {}
+
+        device = logits[0].device
+        batch_size = len(logits)
+        classifier_levels = self.get_input( inputs, "classifier_level" )
+
+        # handles tasks where the prompt has task tokens injected in the middle
+        def prompt_input_to_token( input, quant_level ):
+            if isinstance(input, str):
+                return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)
+
+            return input
+
+        k_lo, k_hi = 1, 20
+        level_loss_factors = self.audio_level_loss_factors
+
+        loss_targets = []
+        loss_logits = []
+        loss_levels = []
+
+        for batch_index, batch in enumerate(inputs):
+            quant_level = quant_levels[batch_index]
+            causal = True
+            task_type = "tts"
+            dropout_mask = None
+            classifier_level = None
+            output_len = 0
+
+            for name, input in batch:
+                if name == "task":
+                    task_type = input
+                elif name == "dropout_mask":
+                    dropout_mask = input
+                elif name == "classifier_level":
+                    classifier_level = input
+
+            # autoregressive, causal (note: unlike calc_loss, no causal shift is applied in this pathway)
+            if classifier_level.startswith("AR:"):
+                causal = True
+            # nonautoregressive, parallel
+            elif classifier_level.startswith("NAR:"):
+                causal = False
+
+            it = 0
+            for name, input in batch:
+                token = None
+                ignored = False
+
+                # non-tokened tasks
+                if name in non_tokened_names:
+                    continue
+                # prom can either be a tensor itself or a list of tensors and strings
+                if name == "prom":
+                    # expand to list if not a list
+                    proms = [ input ] if isinstance(input, torch.Tensor) else input
+                    # iterate over the list to inject their tokens
+                    token = torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] )
+
+                    if logits[batch_index].dim() < 3 and token.dim() >= 2:
+                        token = token[..., 0]
+                elif name == "resp":
+                    token = input
+
+                    # mask found, apply it
+                    if dropout_mask is not None:
+                        token = _dropout_codes( token, dropout_mask, self.ignore_index, swapped = True )
+                # not a special input, inject as-is
+                else:
+                    token = input
+
+                if not isinstance(token, torch.Tensor):
+                    continue
+
+                if token.is_floating_point():
+                    ignored = True # vestigial from calc_loss; this pathway assumes the task output is always discrete tokens
+
+                # grab range of our logits for later
+                seq_len = token.shape[0]
+                start, end = it, it+seq_len
+                it += seq_len + 1 # +1 to incorporate the separator
+
+                # deduce if a name for a task is an input or output
+                if name != task_outputs.get(task_type, name):
+                    continue
+
+                output_len = seq_len # unused here, kept for parity with calc_loss
+
+                for level in range( self.n_resp_levels ):
+                    if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
+                        continue
+
+                    logit = logits[batch_index][level][start:end]
+                    if self.logit_normalization:
+                        logit = logit_normalization( logit, self.logit_normalization )
+
+                    loss_targets.append( token[:, level].long() )
+                    loss_logits.append( logit )
+                    loss_levels.append( level )
+
+                break
+
+        loss_target = torch.cat( loss_targets )
+        loss_logit = torch.cat( loss_logits )
+
+        nll = None
+        acc_k_lo = None
+        acc_k_hi = None
+
+        if compute_hard_loss:
+            weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=loss_logit.device )
+            nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+            # one row per gathered (batch, level) segment; assumes every segment has the same token length
+            nll = nll.view( len(loss_levels), -1 ).mean(dim=-1) * weight
+            nll = nll.mean()
+
+        if compute_acc:
+            n_vocab = loss_logit.shape[-1]
+            if n_vocab >= k_lo:
+                accuracy_metric = MulticlassAccuracy(
+                    n_vocab,
+                    top_k = k_lo,
+                    average="micro",
+                    multidim_average="global",
+                    ignore_index = -100
+                ).to(loss_logit.device)
+                acc_k_lo = accuracy_metric( loss_logit, loss_target )
+
+            if n_vocab >= k_hi:
+                accuracy_metric = MulticlassAccuracy(
+                    n_vocab,
+                    top_k = k_hi,
+                    average="micro",
+                    multidim_average="global",
+                    ignore_index = -100
+                ).to(loss_logit.device)
+                acc_k_hi = accuracy_metric( loss_logit, loss_target )
+
+        if nll is not None:
+            # unlike calc_loss, there is only ever one combined nll, so no per-batch list accumulation is needed
+            loss["nll"] = nll
+
+        if acc_k_lo is not None:
+            acc_k_lo = acc_k_lo.mean()
+            stats[f"acc[k={k_lo}]"] = acc_k_lo
+
+        if acc_k_hi is not None:
+            acc_k_hi = acc_k_hi.mean()
+            stats[f"acc[k={k_hi}]"] = acc_k_hi
+
+        return LossStats(loss, stats)
+
     def forward(
         self,
         inputs: list,
@@ -1246,38 +1421,41 @@ class Base_V2(nn.Module):
             output_attentions = output_attentions,
         )

-        logits = [ logit for logit in output.logits ]
         hidden_states = output.hidden_states

-        grouped_logits = {}
-
-        for batch_index in range( batch_size ):
-            classifier_level = classifier_levels[batch_index]
-            if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
-                classifier_level = "audio"
-
-            if classifier_level not in ["audio", "phn", "text", "len"]:
-                continue
+        if self.use_streamlined_calc_loss:
+            logits = self.audio_decoder( output.logits ) # the streamlined pathway assumes every sequence in the batch is an audio response
+        else:
+            logits = [ logit for logit in output.logits ]
+            grouped_logits = {}

-            if classifier_level not in grouped_logits:
-                grouped_logits[classifier_level] = []
-
-            grouped_logits[classifier_level].append(batch_index)
+            for batch_index in range( batch_size ):
+                classifier_level = classifier_levels[batch_index]
+                if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
+                    classifier_level = "audio"

-        for classifier_level, decoders_indices in grouped_logits.items():
-            if classifier_level == "audio":
-                head = self.audio_decoder
-            elif classifier_level == "phn":
-                head = self.phn_decoder
-            elif classifier_level == "text":
-                head = self.text_decoder
-            elif classifier_level == "len":
-                head = self.len_decoder
+                if classifier_level not in ["audio", "phn", "text", "len"]:
+                    continue
+
+                if classifier_level not in grouped_logits:
+                    grouped_logits[classifier_level] = []
+
+                grouped_logits[classifier_level].append(batch_index)

-            decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
-            decoders_logits = head( decoders_logits )
-            for batch_index, logit in zip( decoders_indices, decoders_logits ):
-                logits[batch_index] = logit
+            for classifier_level, decoders_indices in grouped_logits.items():
+                if classifier_level == "audio":
+                    head = self.audio_decoder
+                elif classifier_level == "phn":
+                    head = self.phn_decoder
+                elif classifier_level == "text":
+                    head = self.text_decoder
+                elif classifier_level == "len":
+                    head = self.len_decoder
+
+                decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
+                decoders_logits = head( decoders_logits )
+                for batch_index, logit in zip( decoders_indices, decoders_logits ):
+                    logits[batch_index] = logit

         # Remove padding
         logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@@ -1291,7 +1469,8 @@ class Base_V2(nn.Module):

         # compute loss if the target is given
         else:
-            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+            loss_func = self.calc_loss_specialized if self.use_streamlined_calc_loss else self.calc_loss
+            loss, stats = loss_func( inputs=inputs, logits=logits, quant_levels=quant_levels )

         # include any additional losses (for example: MoE router)
         if output.loss is not None:
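
A minimal standalone sketch (not part of the patch) of the idea behind calc_loss_specialized: instead of one reduced F.cross_entropy call per (sequence, level) pair, gather everything, make a single unreduced call, then recover the per-level means and apply the manual level weights. The shapes, vocab size, and weighting below are made-up stand-ins for self.n_resp_levels, the classifier vocab, and self.audio_level_loss_factors; equal segment lengths are assumed, matching the assumption the streamlined .view() path makes.

import torch
import torch.nn.functional as F

n_levels, seq_len, n_vocab = 8, 75, 1024                       # stand-in shapes
level_loss_factors = [1.0 / (i + 1) for i in range(n_levels)]  # stand-in weights

logits  = torch.randn(n_levels, seq_len, n_vocab)
targets = torch.randint(0, n_vocab, (n_levels, seq_len))

# slow pathway: one reduced cross-entropy call per level, weighted manually
slow = torch.stack([
    F.cross_entropy(logits[l], targets[l]) * level_loss_factors[l]
    for l in range(n_levels)
])

# streamlined pathway: a single unreduced call over the concatenated levels,
# reshaped back so each row holds one level's token losses
flat = F.cross_entropy(logits.reshape(-1, n_vocab), targets.reshape(-1), reduction="none")
fast = flat.view(n_levels, -1).mean(dim=-1) * torch.tensor(level_loss_factors)

assert torch.allclose(slow, fast, atol=1e-6)  # same per-level losses, one kernel launch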
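
Similarly standalone and illustrative only: the accuracy half of the streamlined path. torchmetrics' MulticlassAccuracy requires top_k <= num_classes, which is why the patch guards on n_vocab before instantiating each metric; the tensor shapes and masked positions below are again stand-ins.

import torch
from torchmetrics.classification import MulticlassAccuracy

n_vocab = 1024
loss_logit = torch.randn(600, n_vocab)           # concatenated [tokens, vocab], as after torch.cat
loss_target = torch.randint(0, n_vocab, (600,))
loss_target[::7] = -100                          # some positions masked out, as with ignore_index

for k in (1, 20):                                # k_lo, k_hi
    if n_vocab >= k:                             # mirrors the patch's guard
        metric = MulticlassAccuracy(
            n_vocab,
            top_k=k,
            average="micro",
            multidim_average="global",
            ignore_index=-100,
        )
        print(f"acc[k={k}]:", metric(loss_logit, loss_target).item())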