diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 21148f0..fca872c 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -79,6 +79,9 @@ class AR_NAR_V2(Base_V2):
 		if cfg.audio_backend == "nemo":
 			rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
 
+		# cringe
+		self.audio_frames_per_second = cfg.dataset.frames_per_second
+
 		# CFG
 		cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
 		cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 7323424..51f1a42 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -1113,7 +1113,7 @@ class Base_V2(nn.Module):
 				#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
 				#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
 
-				aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
+				aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
 				loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
 
 				return LossStats(loss, stats)