This commit is contained in:
mrq 2025-03-10 21:14:56 -05:00
parent 5670fcb23f
commit 8ac03aac8a
2 changed files with 4 additions and 1 deletions

View File

@ -79,6 +79,9 @@ class AR_NAR_V2(Base_V2):
if cfg.audio_backend == "nemo":
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
# cringe
self.audio_frames_per_second = cfg.dataset.frames_per_second
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0

View File

@ -1113,7 +1113,7 @@ class Base_V2(nn.Module):
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
return LossStats(loss, stats)