ugh
This commit is contained in:
parent
5670fcb23f
commit
8ac03aac8a
|
@ -79,6 +79,9 @@ class AR_NAR_V2(Base_V2):
|
||||||
if cfg.audio_backend == "nemo":
|
if cfg.audio_backend == "nemo":
|
||||||
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
|
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
|
||||||
|
|
||||||
|
# cringe
|
||||||
|
self.audio_frames_per_second = cfg.dataset.frames_per_second
|
||||||
|
|
||||||
# CFG
|
# CFG
|
||||||
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
|
||||||
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
|
||||||
|
|
|
@ -1113,7 +1113,7 @@ class Base_V2(nn.Module):
|
||||||
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
||||||
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||||
|
|
||||||
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
|
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
|
||||||
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
|
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
|
||||||
|
|
||||||
return LossStats(loss, stats)
|
return LossStats(loss, stats)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user