diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 0cfc0ed..21148f0 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -835,6 +835,7 @@ def example_usage():
 	# cfg.model.experimental.masking_train_p = 0.5
 	cfg.hyperparameters.batch_size = 1
 	cfg.hyperparameters.gradient_accumulation_steps = 1
+	cfg.model.experimental.use_raw_text_p = 0
 
 	setup_logging()
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index edc5548..ec111a1 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1297,10 +1297,19 @@ class Base_V2(nn.Module):
 				acc_k_hi = None
 
 				if compute_hard_loss:
-					weight = torch.tensor( [ level_loss_factors[level] for level in loss_levels ], device=logit.device )
-					nll = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
-					nll = nll.view( batch_size, 1 if not self.resp_parallel_training else self.n_resp_levels, -1 ).mean(dim=-1) * weight
-					nll = nll.mean()
+					nll = 0
+					nlls = F.cross_entropy( loss_logit, loss_target, reduction='none', ignore_index=self.ignore_index )
+
+					it = 0
+					for seq, level in zip( loss_targets, loss_levels ):
+						seq_len = seq.shape[0]
+						start = it
+						it += seq_len
+						end = it
+
+						nll += nlls[start:end].mean() * level_loss_factors[level]
+
+					nll /= len( loss_targets )
 
 				if compute_acc:
 					n_vocab = loss_logit.shape[-1]
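
For reference, a minimal standalone sketch (not part of the patch) of what the new `base_v2.py` hunk computes: instead of reshaping the flat token-wise cross-entropy into a fixed `(batch, levels, length)` view and weighting it, each target sequence's slice of the flat NLL vector is averaged on its own, scaled by its level's loss factor, and the results are averaged over sequences. The function name, example shapes, and factor values below are illustrative assumptions, not code from the repository.

```python
# Illustrative sketch of the per-sequence, level-weighted NLL in the patch.
# Names/shapes here are assumptions for demonstration only.
import torch
import torch.nn.functional as F

def per_sequence_weighted_nll(loss_logit, loss_target, loss_targets, loss_levels,
                              level_loss_factors, ignore_index=-100):
	# Flat token-wise NLL over the concatenated sequences.
	nlls = F.cross_entropy(loss_logit, loss_target, reduction='none', ignore_index=ignore_index)

	nll = 0.0
	it = 0
	# Walk the flat NLL vector sequence by sequence, scaling each sequence's
	# mean loss by the factor of its RVQ level before averaging over sequences.
	for seq, level in zip(loss_targets, loss_levels):
		start, it = it, it + seq.shape[0]
		nll = nll + nlls[start:it].mean() * level_loss_factors[level]

	return nll / len(loss_targets)

if __name__ == "__main__":
	# Made-up example: three sequences of lengths 4, 3, 5 over a 16-token
	# vocabulary, at levels 0, 1, 2 with decaying loss factors.
	torch.manual_seed(0)
	loss_targets = [torch.randint(0, 16, (n,)) for n in (4, 3, 5)]
	loss_levels = [0, 1, 2]
	level_loss_factors = [1.0, 0.5, 0.25]
	loss_target = torch.cat(loss_targets)
	loss_logit = torch.randn(loss_target.shape[0], 16)
	print(per_sequence_weighted_nll(loss_logit, loss_target, loss_targets,
	                                loss_levels, level_loss_factors))
```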