ugh
This commit is contained in:
parent
5670fcb23f
commit
8ac03aac8a
|
@ -79,6 +79,9 @@ class AR_NAR_V2(Base_V2):
|
|||
if cfg.audio_backend == "nemo":
|
||||
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
|
||||
|
||||
# cringe
|
||||
self.audio_frames_per_second = cfg.dataset.frames_per_second
|
||||
|
||||
# CFG
|
||||
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
||||
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
||||
|
|
|
@ -1113,7 +1113,7 @@ class Base_V2(nn.Module):
|
|||
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
||||
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
|
||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
||||
return LossStats(loss, stats)
|
||||
|
|
Loading…
Reference in New Issue
Block a user